"tests/vscode:/vscode.git/clone" did not exist on "013955b5a72d25836c519d6682d8708330351965"
Unverified Commit c5dbf289 authored by Hailey Schoelkopf, committed by GitHub

Merge branch 'big-refactor' into add-back-cache

parents 03b9db6b 4e0d0e3a
@@ -6,6 +6,8 @@
 from sqlitedict import SqliteDict
 import json
 import hashlib
+from tqdm import tqdm
+
 from lm_eval import utils
 from lm_eval.logger import eval_logger
@@ -178,14 +180,17 @@ class CachingLM:
         remaining_reqs = []
         warned = False
         # figure out which ones are cached and which ones are new
-        for req in requests:
+        eval_logger.info(
+            f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
+        )
+        for req in tqdm(requests):
             hsh = hash_args(attr, req.args)
             if attr == "greedy_until" and req.args[1].get("do_sample", False):
                 # when we are doing non-greedy generation, don't use the cache
                 # (else every "randomly sampled" generation would be identical for repeats > 1).
                 if not warned:
                     eval_logger.warning(
-                        f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed."
+                        f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                     )
                     warned = True
                 res.append(None)
...
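The CachingLM hunk above keys the cache on a hash of the request type and its arguments, and deliberately skips caching whenever sampling is enabled so that repeated sampled generations stay independent. A minimal, self-contained sketch of that idea (the helper names and hashing scheme below are illustrative assumptions, not the harness's own implementation):

```python
import hashlib
import json


def hash_args(attr, args):
    # Illustrative stand-in: build a stable key from the request type and arguments.
    payload = json.dumps([attr, args], sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def should_cache(attr, args):
    # Mirror the diff's rule: never cache non-greedy (sampled) generations,
    # otherwise every "random" sample for repeats > 1 would come back identical.
    if attr == "greedy_until" and len(args) > 1 and args[1].get("do_sample", False):
        return False
    return True


greedy_req = ("Translate to French: hello", {"until": ["\n"]})
sampled_req = ("Write a short poem", {"do_sample": True, "temperature": 1.0})
print(hash_args("greedy_until", greedy_req)[:16], should_cache("greedy_until", greedy_req))
print(hash_args("greedy_until", sampled_req)[:16], should_cache("greedy_until", sampled_req))
```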
@@ -88,6 +88,8 @@ class AnthropicLM(LM):
         if not requests:
             return []

+        requests = [req.args for req in requests]
+
         res = []
         for request in tqdm(requests):
             inp = request[0]
@@ -102,6 +104,9 @@ class AnthropicLM(LM):
                 stop=until,
             )
             res.append(response)
+
+            self.cache_hook.add_partial("greedy_until", request, response)
+
         return res

     def _model_call(self, inps):
...
@@ -564,11 +564,11 @@ class HFLM(LM):
                     until = [kwargs]
                 elif not isinstance(until, list):
                     raise ValueError(
-                        f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
+                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                     )
             else:
                 raise ValueError(
-                    f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
+                    f"Expected `kwargs` to be of type `dict` but got {kwargs}"
                 )
             if not until:
                 until = [self.tok_decode(self.eot_token_id)]
...
@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
                 yield ret, lastuntil

         # todo: more intelligent batching for heterogeneous `until`
-        for chunk, until in tqdm(
+        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
         ):
             inps = []
@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
                 inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)

+            try:
+                until = request_args["until"][
+                    0
+                ]  # TODO: does this handle a list of stop seqs correctly?
+            except KeyError:
+                until = "<|endoftext|>"
+
             response = oa_completion(
                 engine=self.engine,
                 prompt=inps,
@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
                 stop=until,
             )

-            for resp, (context, until_) in zip(response.choices, chunk):
+            for resp, (context, args_) in zip(response.choices, chunk):
                 s = resp["text"]

+                until_ = args_.get("until", [])
+
                 for term in until_:
-                    s = s.split(term)[0]
+                    if len(term) > 0:
+                        s = s.split(term)[0]

                 # partial caching
-                self.cache_hook.add_partial("greedy_until", (context, until_), s)
+                self.cache_hook.add_partial(
+                    "greedy_until", (context, {"until": until_}), s
+                )

                 res.append(s)
...
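The OpenAI hunk now reads the stop sequence from the per-request arguments, falls back to `<|endoftext|>` when none is supplied, and only splits completions on non-empty stop strings. A small sketch of that post-processing under those assumptions (these helpers are for illustration and are not part of the harness):

```python
def resolve_stop(request_args):
    # First user-supplied stop sequence if present, otherwise the GPT-style EOT marker.
    try:
        return request_args["until"][0]
    except (KeyError, IndexError):
        return "<|endoftext|>"


def trim_generation(text, until):
    # Cut the completion at the first occurrence of any non-empty stop sequence.
    for term in until:
        if len(term) > 0:
            text = text.split(term)[0]
    return text


print(resolve_stop({"until": ["\n\n"]}))                  # "\n\n"
print(resolve_stop({}))                                   # "<|endoftext|>"
print(trim_generation("answer\n\nextra", ["\n\n", ""]))   # "answer"
```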
@@ -101,6 +101,10 @@ class TextSynthLM(LM):
                 logprob = resp["logprob"]
                 is_greedy = resp["is_greedy"]
                 res.append((logprob, is_greedy))
+
+                self.cache_hook.add_partial(
+                    "loglikelihood", (context, continuation), (logprob, is_greedy)
+                )
             else:
                 logger.error(
                     f"The following response does not contain `logprobs`. Got:\n{resp}"
@@ -141,6 +145,8 @@ class TextSynthLM(LM):
             if "text" in resp:
                 s = resp["text"]
                 res.append(s)
+
+                self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
             else:
                 logger.error(
                     f"The following response does not contain generated `text`. "
...
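Across the Anthropic, OpenAI, and TextSynth backends the change is the same: each response is handed to `self.cache_hook.add_partial(...)` as soon as it arrives, so an interrupted run can resume from the on-disk cache instead of re-querying the API. A rough sketch of what such a hook could look like on top of `sqlitedict` (the class name, constructor, and key scheme here are assumptions for illustration, not the harness's actual API):

```python
import hashlib
import json

from sqlitedict import SqliteDict


class SimpleCacheHook:
    # Illustrative sketch: persist one (request -> response) pair per call.
    def __init__(self, cache_db):
        self.cache_db = cache_db

    def add_partial(self, attr, req_args, result):
        key = hashlib.sha256(
            json.dumps([attr, req_args], sort_keys=True, default=str).encode("utf-8")
        ).hexdigest()
        with SqliteDict(self.cache_db, autocommit=True) as db:
            db[key] = result


# Record each response as soon as it comes back, mirroring the per-request calls in the diff.
hook = SimpleCacheHook("lm_cache.sqlite")
hook.add_partial("greedy_until", ("Some prompt", {"until": ["\n"]}), "Some completion")
```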