Commit d5cd9655 authored by Leo Gao's avatar Leo Gao
Browse files

Implement caching

parent 1815286c
......@@ -235,9 +235,66 @@ def perplexity(items):
return math.exp(-mean(items))
req_ret_lens = {
'loglikelihood': 2
'loglikelihood': 2,
}
import os
import json
import hashlib
from sqlitedict import SqliteDict
def hash_args(args):
dat = b""
for arg in args:
assert isinstance(arg, str) or isinstance(arg, int)
dat += str(arg).encode()
dat += b"\0"
return hashlib.sha256(dat).hexdigest()
class CachingLM:
def __init__(self, lm, cache_db):
self.lm = lm
self.cache_db = cache_db
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
def __getattr__(self, attr):
def fn(requests):
res = []
remaining_reqs = []
# figure out which ones are cached and which ones are new
for req in requests:
hsh = attr + '_' + hash_args(req)
if hsh in self.dbdict:
ob = self.dbdict[hsh]
assert ob is not None
res.append(ob)
else:
res.append(None)
remaining_reqs.append(req)
# actually run the LM
rem_res = getattr(self.lm, attr)(remaining_reqs)
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
while res[resptr] is not None: resptr += 1
res[resptr] = r
# caching
hsh = attr + '_' + hash_args(req)
self.dbdict[hsh] = r
return res
return fn
class Request:
def __init__(self, type, args, index=None):
......
......@@ -5,7 +5,7 @@ import random
import itertools
import collections
from lm_eval import models, tasks, evaluator
from lm_eval import models, tasks, evaluator, base
def parse_args():
......@@ -18,14 +18,19 @@ def parse_args():
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--cache', action="store_true")
return parser.parse_args()
def main():
args = parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
lm = models.get_model(args.model).create_from_arg_string(args.model_args)
if args.cache:
lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db')
if args.tasks == "all_tasks":
task_names = tasks.ALL_TASKS
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment