Unverified Commit 4e0d0e3a authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #619 from EleutherAI/cachinglm-only

[Refactor] CachingLM support via `--use_cache`
parents 9dea125b 5a5442ff
import abc
import os
from typing import Union
from sqlitedict import SqliteDict
import json
import hashlib
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
class LM(abc.ABC):
@@ -12,6 +19,7 @@ class LM(abc.ABC):
(inputs/outputs should be tokenization-agnostic.)
"""
self.cache_hook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests):
@@ -118,3 +126,104 @@ class LM(abc.ABC):
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return 1
def set_cache_hook(self, cache_hook):
self.cache_hook = cache_hook
### SQLite-based caching of LM responses
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
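As a point of reference (not part of the diff), the cache key is a SHA-256 over the method name plus the JSON-serialized request arguments, so identical requests map to the same SQLite key across runs. A minimal sketch, assuming hash_args as defined above and the lm_eval.api.model module path referenced later in this diff:

from lm_eval.api.model import hash_args

key_a = hash_args("loglikelihood", ("The capital of France is", " Paris"))
key_b = hash_args("loglikelihood", ("The capital of France is", " Paris"))
assert key_a == key_b  # deterministic: a rerun hits the same cache entry
# Arguments must be JSON-serializable (strings, numbers, lists/dicts of these).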
class CacheHook:
def __init__(self, cachinglm):
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res):
if self.dbdict is None:
return
hsh = hash_args(attr, req)
self.dbdict[hsh] = res
class CachingLM:
def __init__(self, lm, cache_db):
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
Underlying LM
:param cache_db: str
Path to cache db
"""
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr):
lm_attr = getattr(self.lm, attr)
if not callable(lm_attr):
return lm_attr
def fn(requests):
res = []
remaining_reqs = []
warned = False
# figure out which ones are cached and which ones are new
eval_logger.info(
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
)
for req in tqdm(requests):
hsh = hash_args(attr, req.args)
if attr == "greedy_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
eval_logger.warning(
f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
)
warned = True
res.append(None)
remaining_reqs.append(req)
elif hsh in self.dbdict:
ob = self.dbdict[hsh]
assert ob is not None
res.append(ob)
else:
res.append(None)
remaining_reqs.append(req)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
while res[resptr] is not None:
resptr += 1
res[resptr] = r
# caching
hsh = hash_args(attr, req.args)
self.dbdict[hsh] = r
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
return CacheHook(self)
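A rough usage sketch (not part of this commit; MyLM and requests are placeholders for any concrete LM subclass and its Instance objects): attribute access falls through to the wrapped model, while callable request methods are intercepted and backed by the SQLite dict.

underlying = MyLM()                                    # placeholder backend
lm = CachingLM(underlying, "lm_cache/mylm_rank0.db")   # parent dir is created if missing

# The first call misses the cache and runs the model; an identical second call
# is answered from the SQLite db via hash_args(attr, req.args) lookups.
results = lm.loglikelihood(requests)
results_again = lm.loglikelihood(requests)

# Non-callable attributes (e.g. lm.rank) pass straight through, because
# __getattr__ only wraps callables.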
@@ -52,7 +52,6 @@ class TaskConfig(dict):
task: str = None
group: Union[str, list] = None
reference: str = None
dataset_path: str = None
dataset_name: str = None
@@ -67,6 +66,8 @@ class TaskConfig(dict):
doc_to_target: Union[Callable, str] = None
use_prompt: str = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
num_fewshot: int = 0
batch_size: int = 1
@@ -76,8 +77,6 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
@@ -343,7 +342,7 @@ class Task(abc.ABC):
fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random()
)
# TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
# TODO: we should override this if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
@@ -773,7 +772,7 @@ class ConfigurableTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
arguments=("", " {}".format(choice)),
idx=i,
**kwargs,
)
......
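The whitespace change above ("{}" to " {}") matters because loglikelihood scoring concatenates context and continuation before tokenization; with an empty context, the continuation now carries its own leading separator, matching the default target_delimiter of " ". An illustrative sketch with a made-up choice string:

choice = "Paris"
arguments = ("", " {}".format(choice))      # new behavior -> ("", " Paris")
old_arguments = ("", "{}".format(choice))   # old behavior -> ("", "Paris")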
@@ -39,7 +39,7 @@ def simple_evaluate(
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
use_cache=None,
limit=None,
bootstrap_iters=100000,
check_integrity=False,
@@ -64,8 +64,8 @@ def simple_evaluate(
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
Whether or not to cache
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
@@ -99,6 +99,16 @@ def simple_evaluate(
assert isinstance(model, lm_eval.api.model.LM)
lm = model
if use_cache is not None:
print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank" + str(lm.rank) + ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
if check_integrity:
@@ -127,7 +137,7 @@ def simple_evaluate(
if hasattr(lm, "batch_sizes")
else [],
"device": device,
"no_cache": no_cache,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
......
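A hedged sketch of the new evaluator entry point (model and task values are placeholders; the argument names follow the simple_evaluate signature referenced in this diff): passing a path prefix via use_cache makes each rank write to its own <prefix>_rank<rank>.db file, while None keeps caching off.

from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf",                    # placeholder model spec
    tasks=["lambada_openai"],      # placeholder task list
    device="cuda:0",
    use_cache="lm_cache/run1",     # rank 0 writes lm_cache/run1_rank0.db
    limit=100,
)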
@@ -88,6 +88,8 @@ class AnthropicLM(LM):
if not requests:
return []
requests = [req.args for req in requests]
res = []
for request in tqdm(requests):
inp = request[0]
@@ -102,6 +104,9 @@ class AnthropicLM(LM):
stop=until,
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
return res
def _model_call(self, inps):
......
@@ -486,6 +486,8 @@ class HFLM(LM):
res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
@@ -497,26 +499,28 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
for context, gen_kwargs in tqdm(
re_ord.get_reordered(), disable=(self.rank != 0)
):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [gen_kwargs]
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
@@ -539,7 +543,7 @@ class HFLM(LM):
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
**kwargs,
)
cont_toks_list = cont[0].tolist()
@@ -556,4 +560,6 @@ class HFLM(LM):
res.append(s)
self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
return re_ord.get_original(res)
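For a cache hit on a later run, the key written by cache_hook.add_partial must equal the key CachingLM computes from req.args, i.e. hash_args(attr, req.args); note HFLM passes the original gen_kwargs rather than the mutated kwargs copy precisely so the tuples match. A small illustrative check (values made up):

from lm_eval.api.model import hash_args

context, gen_kwargs = "Q: 2+2=?\nA:", {"until": ["\n"], "max_gen_toks": 16}
write_key = hash_args("greedy_until", (context, gen_kwargs))  # key used by add_partial
read_key = hash_args("greedy_until", (context, gen_kwargs))   # key used by CachingLM lookup
assert write_key == read_key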
@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
try:
until = request_args["until"][
0
] # TODO: does this handle a list of stop seqs correctly?
except KeyError:
until = "<|endoftext|>"
response = oa_completion(
engine=self.engine,
prompt=inps,
@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"]
until_ = args_.get("until", [])
for term in until_:
s = s.split(term)[0]
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s)
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
)
res.append(s)
......
@@ -101,6 +101,10 @@ class TextSynthLM(LM):
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
self.cache_hook.add_partial(
"loglikelihood", (context, continuation), (logprob, is_greedy)
)
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
@@ -141,6 +145,8 @@ class TextSynthLM(LM):
if "text" in resp:
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......
@@ -39,7 +39,7 @@ def parse_args():
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_cache", type=str, default=None)
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
@@ -85,7 +85,7 @@ def main():
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
use_cache=args.use_cache,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
......