Unverified Commit c5dbf289 authored by Hailey Schoelkopf, committed by GitHub

Merge branch 'big-refactor' into add-back-cache

parents 03b9db6b 4e0d0e3a
@@ -6,6 +6,8 @@ from sqlitedict import SqliteDict
import json
import hashlib
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
@@ -178,14 +180,17 @@ class CachingLM:
remaining_reqs = []
warned = False
# figure out which ones are cached and which ones are new
for req in requests:
eval_logger.info(
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
)
for req in tqdm(requests):
hsh = hash_args(attr, req.args)
if attr == "greedy_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
eval_logger.warning(
f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed."
f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
)
warned = True
res.append(None)
......
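
For context on the lookup above: hash_args(attr, req.args) produces the key under which a response is stored in the SqliteDict-backed cache, and sampled (do_sample=True) generations are deliberately never served from it. The sketch below is a minimal reconstruction of that flow under those assumptions, not the harness's own helper; the request data and the lm_cache.db filename are made up.

import hashlib
import json

from sqlitedict import SqliteDict


def hash_args(attr, args):
    # Serialize the method name plus the request arguments into a stable
    # string, then hash it so the key is short and filesystem-safe.
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


# Hypothetical lookup loop mirroring the cached/remaining split above.
cache_db = SqliteDict("lm_cache.db", autocommit=True)
requests = [("greedy_until", ("Q: 2+2=?\nA:", {"until": ["\n"], "do_sample": False}))]

remaining_reqs = []
for attr, args in requests:
    hsh = hash_args(attr, args)
    if attr == "greedy_until" and args[1].get("do_sample", False):
        # Sampled generations are non-deterministic, so a cached answer would
        # make every repeat identical; always re-run these.
        remaining_reqs.append((attr, args))
    elif hsh in cache_db:
        print("cache hit:", cache_db[hsh])
    else:
        remaining_reqs.append((attr, args))
cache_db.close()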
@@ -88,6 +88,8 @@ class AnthropicLM(LM):
if not requests:
return []
requests = [req.args for req in requests]
res = []
for request in tqdm(requests):
inp = request[0]
@@ -102,6 +104,9 @@ class AnthropicLM(LM):
stop=until,
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
return res
def _model_call(self, inps):
......
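
The self.cache_hook.add_partial(...) call added after each Anthropic response is what lets CachingLM persist results as they arrive rather than only at the end of a run. A toy stand-in for that hook, purely illustrative and not the harness API:

class ToyCacheHook:
    """Minimal stand-in for the cache hook; the real one writes to the SqliteDict cache."""

    def __init__(self):
        self.store = {}

    def add_partial(self, attr, req, res):
        # Record each (request, response) pair as soon as it is available, so an
        # interrupted evaluation keeps the work already completed.
        self.store[(attr, repr(req))] = res


cache_hook = ToyCacheHook()
cache_hook.add_partial("greedy_until", ("Some prompt", {"until": ["\n\n"]}), "model output")
print(len(cache_hook.store))  # -> 1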
@@ -564,11 +564,11 @@ class HFLM(LM):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
......
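
The until handling above reduces to: accept a bare string or a list of strings from the generation kwargs, reject anything else, and fall back to the EOT token when no stop sequence is given. A self-contained sketch of that rule, with an illustrative helper name and a hard-coded EOT string standing in for self.tok_decode(self.eot_token_id):

from typing import List


def normalize_until(kwargs: dict, eot_text: str = "<|endoftext|>") -> List[str]:
    # Accept `until` as a single string or a list of strings; anything else
    # (or a non-dict kwargs) is a caller error. Fall back to the EOT text.
    if not isinstance(kwargs, dict):
        raise ValueError(f"Expected `kwargs` to be of type `dict` but got {kwargs}")
    until = kwargs.get("until", None)
    if isinstance(until, str):
        until = [until]
    elif until is not None and not isinstance(until, list):
        raise ValueError(
            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
        )
    if not until:
        until = [eot_text]
    return until


assert normalize_until({"until": "\n"}) == ["\n"]
assert normalize_until({}) == ["<|endoftext|>"]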
@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
try:
until = request_args["until"][
0
] # TODO: does this handle a list of stop seqs correctly?
except KeyError:
until = "<|endoftext|>"
response = oa_completion(
engine=self.engine,
prompt=inps,
@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"]
until_ = args_.get("until", [])
for term in until_:
s = s.split(term)[0]
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s)
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
)
res.append(s)
......
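
The post-processing change above does two things: it reads the stop sequences from each request's own args rather than a shared until value, and it skips empty terms before splitting, since str.split("") raises a ValueError. A small standalone sketch of the truncation step (the function name is illustrative):

def truncate_at_stops(text: str, until: list) -> str:
    # Cut the generation at the first occurrence of any stop sequence,
    # ignoring empty strings, which str.split() rejects.
    for term in until:
        if len(term) > 0:
            text = text.split(term)[0]
    return text


print(truncate_at_stops("Paris.\nQ: next question?", ["\n", ""]))  # -> "Paris."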
@@ -101,6 +101,10 @@ class TextSynthLM(LM):
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
self.cache_hook.add_partial(
"loglikelihood", (context, continuation), (logprob, is_greedy)
)
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
@@ -141,6 +145,8 @@ class TextSynthLM(LM):
if "text" in resp:
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......
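
The TextSynth changes write through the same hook with two value shapes: loglikelihood results are cached as a (logprob, is_greedy) tuple, greedy_until results as the generated string. Purely illustrative calls, using a throwaway hook object rather than the harness's:

class _Hook:
    def __init__(self):
        self.store = {}

    def add_partial(self, attr, key, value):
        # Keyed the same way it will later be looked up: (method, request args).
        self.store[(attr, repr(key))] = value


hook = _Hook()
hook.add_partial("loglikelihood", ("ctx", " continuation"), (-1.23, True))
hook.add_partial("greedy_until", ("ctx", {"until": ["\n"]}), "generated text")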