Commit 77d4b087 authored by Leo Gao's avatar Leo Gao
Browse files

Fix caching

parent 1b467c57
...@@ -275,13 +275,9 @@ import json ...@@ -275,13 +275,9 @@ import json
import hashlib import hashlib
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
def hash_args(args): def hash_args(attr, args):
dat = b"" dat = json.dumps([attr] + list(args))
for arg in args: return hashlib.sha256(dat.encode('utf-8')).hexdigest()
assert isinstance(arg, str) or isinstance(arg, int)
dat += str(arg).encode()
dat += b"\0"
return hashlib.sha256(dat).hexdigest()
class CachingLM: class CachingLM:
...@@ -298,7 +294,7 @@ class CachingLM: ...@@ -298,7 +294,7 @@ class CachingLM:
# figure out which ones are cached and which ones are new # figure out which ones are cached and which ones are new
for req in requests: for req in requests:
hsh = attr + '_' + hash_args(req) hsh = hash_args(attr, req)
if hsh in self.dbdict: if hsh in self.dbdict:
ob = self.dbdict[hsh] ob = self.dbdict[hsh]
...@@ -320,9 +316,9 @@ class CachingLM: ...@@ -320,9 +316,9 @@ class CachingLM:
res[resptr] = r res[resptr] = r
# caching # caching
hsh = attr + '_' + hash_args(req) hsh = hash_args(attr, req)
self.dbdict[hsh] = r self.dbdict[hsh] = r
self.dbdict.commit()
return res return res
return fn return fn
...@@ -344,6 +340,9 @@ class Request: ...@@ -344,6 +340,9 @@ class Request:
def __getitem__(self, i): def __getitem__(self, i):
return Request(self.type, self.args, i) return Request(self.type, self.args, i)
def __eq__(self, other):
return self.type == other.type and self.args == other.args and self.index == other.index
class RequestFactory: class RequestFactory:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment