Commit d1c5abef authored by haileyschoelkopf

add cache hooks to API LMs

parent 732f7ed2
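
For context: LM subclasses in the harness expose a `cache_hook` attribute, and calling `cache_hook.add_partial(attr, req, res)` records a request/response pair so a caching wrapper can replay it on later runs instead of re-hitting the API. The sketch below illustrates that pattern; the class body, the `dbdict` attribute, and the hashing details are assumptions for illustration, not the exact upstream implementation. The hunks that follow call `add_partial` immediately after each API response is appended to `res`.

```python
# Illustrative sketch of the cache-hook pattern this commit relies on.
# Names (CacheHook, dbdict) and hashing details are assumptions, not the
# exact upstream implementation.
import hashlib
import json


class CacheHook:
    def __init__(self, cache_db=None):
        # With no cache database attached, add_partial is a no-op, so
        # uncached LMs can call it unconditionally.
        self.dbdict = cache_db

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
        # Key each (method, request) pair; a caching wrapper can then serve
        # repeated requests from this store instead of re-calling the API.
        key = hashlib.sha256(json.dumps([attr, req], default=str).encode()).hexdigest()
        self.dbdict[key] = res
```
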
@@ -88,6 +88,8 @@ class AnthropicLM(LM):
         if not requests:
             return []
 
+        requests = [req.args for req in requests]
+
         res = []
         for request in tqdm(requests):
             inp = request[0]
@@ -102,6 +104,9 @@ class AnthropicLM(LM):
                 stop=until,
             )
             res.append(response)
+
+            self.cache_hook.add_partial("greedy_until", request, response)
+
         return res
 
     def _model_call(self, inps):
...
@@ -101,6 +101,10 @@ class TextSynthLM(LM):
                 logprob = resp["logprob"]
                 is_greedy = resp["is_greedy"]
                 res.append((logprob, is_greedy))
+
+                self.cache_hook.add_partial(
+                    "loglikelihood", (context, continuation), (logprob, is_greedy)
+                )
             else:
                 logger.error(
                     f"The following response does not contain `logprobs`. Got:\n{resp}"
@@ -141,6 +145,8 @@ class TextSynthLM(LM):
             if "text" in resp:
                 s = resp["text"]
                 res.append(s)
+
+                self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
             else:
                 logger.error(
                     f"The following response does not contain generated `text`. "
...