Commit 1c0ff968 authored by haileyschoelkopf

add CachingLM back

parent 9dea125b
import abc
import os
from typing import Union
from sqlitedict import SqliteDict
import json
import hashlib
from lm_eval import utils
from lm_eval.logger import eval_logger
class LM(abc.ABC):
@@ -12,6 +17,7 @@ class LM(abc.ABC):
        (inputs/outputs should be tokenization-agnostic.)
        """
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
@@ -118,3 +124,101 @@ class LM(abc.ABC):
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return 1

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook
### SQLite-based caching of LM responses
def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
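# Illustration (editor's note, not part of the commit): the key is a pure
# function of the method name and the request args, so identical requests map
# to the same cache row on every run:
#
#     hash_args("loglikelihood", ["ctx", "cont"])   # same 64-char hex digest every call
#     hash_args("loglikelihood", ["ctx", "cont!"])  # different args -> different key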
class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return

        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res
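# Usage sketch (editor's note, assuming an LM subclass that computes responses
# one request at a time): model implementations can push each freshly computed
# result into the cache through the hook, e.g.
#
#     self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
#
# A hook built as CacheHook(None) is a no-op, so this call is safe even when no
# CachingLM wrapper is attached.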
class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())
    def __getattr__(self, attr):
        lm_attr = getattr(self.lm, attr)
        if not callable(lm_attr):
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req.args)
                if attr == "greedy_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]
                    assert ob is not None
                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r

            self.dbdict.commit()

            return res

        return fn
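    # Worked example of the merge above (editor's note, not in the commit):
    # given requests [A, B, C] where only B is cached, the first pass yields
    # res = [None, res_B, None] and remaining_reqs = [A, C]. The underlying LM
    # returns rem_res = [res_A, res_C], which is spliced into the None slots in
    # order, producing [res_A, res_B, res_C]; the two new results are written to
    # the SQLite cache. Note that results for do_sample generations are also
    # written here, but they are never read back, because the do_sample branch
    # above skips cache lookups for those requests.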
    def get_cache_hook(self):
        return CacheHook(self)
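# Usage sketch (editor's note, not part of the commit): wrap any LM instance to
# get read-through caching. `MyLM` here is a hypothetical LM subclass; in
# practice the harness's registered model classes would be wrapped the same way.
#
#     lm = CachingLM(MyLM(), "lm_cache/mylm.db")
#     results = lm.loglikelihood(requests)  # misses hit MyLM, hits come from SQLite
#
# Because __getattr__ intercepts every callable attribute, loglikelihood,
# loglikelihood_rolling, and greedy_until all share the same caching path.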