Commit efbe6e7f authored by Leo Gao

Implement partial caching

Now, if a run gets interrupted halfway, you can easily resume it.
parent 8fe59e59
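For context, this is roughly how the new partial cache is meant to be used. The module paths, constructor arguments, and cache path below are assumptions based on the classes touched in this diff, not code from the commit itself:

# Minimal usage sketch (assumed, not part of this commit): wrap a model in
# CachingLM so every answered request is written to an on-disk SQLite cache
# as soon as it is computed. If the run is killed halfway, rerunning the same
# evaluation reuses the cached answers and only recomputes the missing ones.
from lm_eval.base import CachingLM            # module path assumed
from lm_eval.models.gpt2 import GPT2LM        # module path assumed

lm = CachingLM(GPT2LM(device="cuda"), "lm_cache/gpt2.db")   # cache path is an example
# loglikelihood / greedy_until calls now go through the cache; misses are
# answered by the wrapped GPT2LM and stored immediately via its cache hook.
results = lm.loglikelihood([("The capital of France is", " Paris")])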
@@ -6,6 +6,9 @@ from lm_eval.metrics import mean
class LM(abc.ABC):
    def __init__(self):
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
@@ -60,6 +63,9 @@ class LM(abc.ABC):
        """
        return cls()

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook


class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
@@ -251,6 +257,21 @@ def hash_args(attr, args):
    return hashlib.sha256(dat.encode('utf-8')).hexdigest()


class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db):
        self.lm = lm
@@ -258,6 +279,9 @@ class CachingLM:
        os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        def fn(requests):
            res = []
@@ -293,6 +317,9 @@ class CachingLM:
            return res
        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class Request:
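To keep the hook transparent for models that are never wrapped, LM.__init__ installs a CacheHook(None) whose add_partial silently does nothing; CachingLM then swaps in a live hook backed by its SqliteDict. A small sketch of both cases (values and the cache path are illustrative, not from this commit):

# No-op case: a bare LM keeps the default hook, so add_partial is ignored.
noop_hook = CacheHook(None)
noop_hook.add_partial("loglikelihood", ("ctx", " cont"), (-1.23, True))   # does nothing

# Wrapped case: CachingLM installs a hook that shares its dbdict, so each
# call stores the result under the hash_args(attr, req) key.
caching_lm = CachingLM(GPT2LM(), "lm_cache/example.db")    # example path
caching_lm.get_cache_hook().add_partial(
    "loglikelihood", ("ctx", " cont"), (-1.23, True)       # persisted to SQLite immediately
)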
@@ -10,6 +10,7 @@ class GPT2LM(LM):
    MAX_GEN_TOKS = 256

    def __init__(self, device=None, pretrained='gpt2'):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
@@ -69,7 +70,12 @@ class GPT2LM(LM):
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

                res.append((float(logits.sum()), bool(max_equal)))
                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)

                res.append(answer)

        return reord.get_original(res)
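Because the add_partial call above runs once per (context, continuation) inside the scoring loop, and CachingLM opens its SqliteDict with autocommit=True, each finished answer is durable the moment it is computed. A sketch of inspecting a cache left behind by an interrupted run (the file name is an example):

# Each entry maps hash_args("loglikelihood", (context, continuation)) to the
# stored (logprob, is_greedy) answer, written as the run progressed.
from sqlitedict import SqliteDict

with SqliteDict("lm_cache/gpt2.db") as db:          # example path
    print(len(db), "requests already answered")     # survives the interruption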
@@ -103,6 +109,9 @@ class GPT2LM(LM):
            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)
@@ -48,6 +48,7 @@ class GPT3LM(LM):
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
@@ -104,8 +105,13 @@ class GPT3LM(LM):
                logprobs=10,
            )

            for resp, ctxlen in zip(response.choices, ctxlens):
                res.append(get_result(resp, ctxlen))
            for resp, ctxlen, (context, continuation) in zip(response.choices, ctxlens, chunk):
                answer = get_result(resp, ctxlen)
                res.append(answer)

                # partial caching
                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)

        return reord.get_original(res)
@@ -149,13 +155,15 @@ class GPT3LM(LM):
                stop=until
            )

            for resp in response.choices:
            for resp, (context, until) in zip(response.choices, chunk):
                s = resp['text']

                for term in until:
                    s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until), s)

                res.append(s)

        return reord.get_original(res)
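Both GPT-3 loops now zip the API responses with the original chunk, so each cached value is keyed by the exact request tuple that produced it. Assuming CachingLM looks requests up with the same hash_args key (its lookup code is outside this hunk), a resumed run can find those entries again; the path and request below are examples:

# Illustrative lookup (assumed behavior of the cache's read side, not shown in
# this diff): recomputing hash_args for the identical request recovers the
# completion written by add_partial above.
from lm_eval.base import hash_args            # module path assumed
from sqlitedict import SqliteDict

req = ("Q: What is 2+2?\nA:", ["\n"])         # (context, until) as cached above
with SqliteDict("lm_cache/gpt3.db") as db:    # example path
    key = hash_args("greedy_until", req)
    if key in db:
        print("cached completion:", db[key])  # reused instead of re-querying the API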