Commit efbe6e7f authored by Leo Gao

Implement partial caching

Now, if a run gets interrupted halfway, you can easily resume it
parent 8fe59e59
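The pieces fit together as follows: CachingLM already stores batched results in a SqliteDict keyed by hash_args, and this commit adds a CacheHook so each backend writes every individual result into that same store the moment it is computed, which is what makes an interrupted run resumable. A minimal usage sketch (the module paths lm_eval.base and lm_eval.models.gpt2, and the cache path, are assumptions not shown in this diff):

    from lm_eval.base import CachingLM
    from lm_eval.models.gpt2 import GPT2LM

    # Wrap any LM; results are persisted to the SQLite file as they are produced.
    lm = CachingLM(GPT2LM(pretrained='gpt2'), 'lm_cache/gpt2.db')

    requests = [("The cat sat on the", " mat"), ("Paris is the capital of", " France")]

    # First run: each answer is written to the cache right after it is computed,
    # so a crash partway through loses nothing that was already done.
    print(lm.loglikelihood(requests))

    # A rerun with the same cache_db serves every request whose hash is already
    # in the SqliteDict from disk instead of recomputing it.
    print(lm.loglikelihood(requests))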
@@ -6,6 +6,9 @@ from lm_eval.metrics import mean

 class LM(abc.ABC):
+    def __init__(self):
+        self.cache_hook = CacheHook(None)
+
     @abc.abstractmethod
     def loglikelihood(self, requests):
         """Compute log-likelihood of generating a continuation from a context.
@@ -60,6 +63,9 @@ class LM(abc.ABC):
         """
         return cls()

+    def set_cache_hook(self, cache_hook):
+        self.cache_hook = cache_hook
+

 class Task(abc.ABC):
     """A task represents an entire benchmark including its dataset, problems,
@@ -251,6 +257,21 @@ def hash_args(attr, args):
     return hashlib.sha256(dat.encode('utf-8')).hexdigest()


+class CacheHook:
+    def __init__(self, cachinglm):
+        if cachinglm is None:
+            self.dbdict = None
+            return
+
+        self.dbdict = cachinglm.dbdict
+
+    def add_partial(self, attr, req, res):
+        if self.dbdict is None:
+            return
+
+        hsh = hash_args(attr, req)
+        self.dbdict[hsh] = res
+
+
 class CachingLM:
     def __init__(self, lm, cache_db):
         self.lm = lm
@@ -258,6 +279,9 @@ class CachingLM:
         os.makedirs(os.path.dirname(cache_db), exist_ok=True)
         self.dbdict = SqliteDict(cache_db, autocommit=True)

+        # add hook to lm
+        lm.set_cache_hook(self.get_cache_hook())
+
     def __getattr__(self, attr):
         def fn(requests):
             res = []
@@ -293,6 +317,9 @@ class CachingLM:
             return res

         return fn
+
+    def get_cache_hook(self):
+        return CacheHook(self)


 class Request:
...
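Both write paths agree on the key: CachingLM looks up hash_args(attr, req) before deciding whether to dispatch a request to the wrapped model, and add_partial stores under exactly the same key afterwards, so the next run sees a hit. A self-contained sketch of that keying (the json.dumps serialization inside hash_args is an assumption, since the hunk above only shows the sha256 step, and the hook here takes the dict directly instead of a CachingLM so the example runs on its own):

    import hashlib
    import json

    def hash_args(attr, args):
        # hash the method name together with the request tuple
        # (the json serialization is assumed; only the sha256 step appears above)
        dat = json.dumps([attr] + list(args))
        return hashlib.sha256(dat.encode('utf-8')).hexdigest()

    class CacheHook:
        def __init__(self, dbdict):
            self.dbdict = dbdict  # None => no-op hook, as on a bare LM

        def add_partial(self, attr, req, res):
            if self.dbdict is None:
                return
            self.dbdict[hash_args(attr, req)] = res

    db = {}  # a plain dict standing in for the SqliteDict
    hook = CacheHook(db)
    hook.add_partial("loglikelihood", ("The cat sat on the", " mat"), (-3.2, True))

    # A later run builds the same key for the same request and finds it cached.
    key = hash_args("loglikelihood", ("The cat sat on the", " mat"))
    print(key in db)  # True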
@@ -10,6 +10,7 @@ class GPT2LM(LM):
     MAX_GEN_TOKS = 256

     def __init__(self, device=None, pretrained='gpt2'):
+        super().__init__()
         if device:
             self.device = torch.device(device)
         else:
@@ -69,7 +70,12 @@ class GPT2LM(LM):
                 logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]

-                res.append((float(logits.sum()), bool(max_equal)))
+                answer = (float(logits.sum()), bool(max_equal))
+
+                # partial caching
+                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
+
+                res.append(answer)

         return reord.get_original(res)
@@ -103,6 +109,9 @@ class GPT2LM(LM):
                 for term in until:
                     s = s.split(term)[0]

+                # partial caching
+                self.cache_hook.add_partial("greedy_until", (context, until), s)
+
                 res.append(s)

         return reord.get_original(res)
@@ -48,6 +48,7 @@ class GPT3LM(LM):
         :param truncate: bool
             Truncate input if too long (if False and input is too long, throw error)
         """
+        super().__init__()
         import openai
         self.engine = engine
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
@@ -104,8 +105,13 @@ class GPT3LM(LM):
                 logprobs=10,
             )

-            for resp, ctxlen in zip(response.choices, ctxlens):
-                res.append(get_result(resp, ctxlen))
+            for resp, ctxlen, (context, continuation) in zip(response.choices, ctxlens, chunk):
+                answer = get_result(resp, ctxlen)
+                res.append(answer)
+
+                # partial caching
+                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)

         return reord.get_original(res)
@@ -149,13 +155,15 @@ class GPT3LM(LM):
                 stop=until
             )

-            for resp in response.choices:
+            for resp, (context, until) in zip(response.choices, chunk):
                 s = resp['text']

                 for term in until:
                     s = s.split(term)[0]

+                # partial caching
+                self.cache_hook.add_partial("greedy_until", (context, until), s)
+
                 res.append(s)

         return reord.get_original(res)
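The backend-side pattern is the same in both models above and would be the same for any new LM subclass: compute one answer, hand it to self.cache_hook.add_partial under the method name and the request tuple, then append it to the result list. A hypothetical backend sketching that pattern (DummyLM and its stand-in scoring are made up; it assumes loglikelihood and greedy_until are the only abstract methods on LM, as in base.py here):

    from lm_eval.base import LM

    class DummyLM(LM):
        def loglikelihood(self, requests):
            res = []
            for context, continuation in requests:
                answer = (0.0, True)  # stand-in; a real backend scores the continuation here

                # partial caching: persist the result the moment it exists
                self.cache_hook.add_partial("loglikelihood", (context, continuation), answer)
                res.append(answer)
            return res

        def greedy_until(self, requests):
            res = []
            for context, until in requests:
                s = ""  # stand-in; a real backend generates up to a stop sequence here

                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until), s)
                res.append(s)
            return res

Because LM.__init__ now installs a no-op CacheHook(None), these calls are harmless when the model is used without a CachingLM wrapper; wrapping it is what activates the writes.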