Commit d1c5abef authored by haileyschoelkopf
Browse files

add cache hooks to API LMs

parent 732f7ed2
......@@ -88,6 +88,8 @@ class AnthropicLM(LM):
if not requests:
return []
requests = [req.args for req in requests]
res = []
for request in tqdm(requests):
inp = request[0]
......@@ -102,6 +104,9 @@ class AnthropicLM(LM):
stop=until,
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
return res
def _model_call(self, inps):
......
......@@ -101,6 +101,10 @@ class TextSynthLM(LM):
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
self.cache_hook.add_partial(
"loglikelihood", (context, continuation), (logprob, is_greedy)
)
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
......@@ -141,6 +145,8 @@ class TextSynthLM(LM):
if "text" in resp:
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment