Unverified Commit 4e0d0e3a authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #619 from EleutherAI/cachinglm-only

[Refactor] CachingLM support via `--use_cache`
parents 9dea125b 5a5442ff
import abc
import os
from typing import Union
from sqlitedict import SqliteDict
import json
import hashlib
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
class LM(abc.ABC):
...@@ -12,6 +19,7 @@ class LM(abc.ABC):
(inputs/outputs should be tokenization-agnostic.)
"""
self.cache_hook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests):
...@@ -118,3 +126,104 @@ class LM(abc.ABC):
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return 1
def set_cache_hook(self, cache_hook):
self.cache_hook = cache_hook
### SQLite-based caching of LM responses
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
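The helper above defines the entire cache-key scheme: the request type plus its arguments are JSON-encoded and hashed. A hedged sketch of what that implies in practice (the example strings are illustrative, not from the commit):

```python
# Using hash_args as defined above: the key is a SHA-256 digest of the
# JSON-encoded [method name, *request args], so identical requests map to
# the same SqliteDict entry across runs. Request args must therefore be
# JSON-serializable (strings, numbers, lists, dicts).
key = hash_args("loglikelihood", ("The capital of France is", " Paris"))
print(len(key))  # 64 -- a stable hex digest used as the cache key
```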
class CacheHook:
def __init__(self, cachinglm):
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res):
if self.dbdict is None:
return
hsh = hash_args(attr, req)
self.dbdict[hsh] = res
class CachingLM:
def __init__(self, lm, cache_db):
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
Underlying LM
:param cache_db: str
Path to cache db
"""
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr):
lm_attr = getattr(self.lm, attr)
if not callable(lm_attr):
return lm_attr
def fn(requests):
res = []
remaining_reqs = []
warned = False
# figure out which ones are cached and which ones are new
eval_logger.info(
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
)
for req in tqdm(requests):
hsh = hash_args(attr, req.args)
if attr == "greedy_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
eval_logger.warning(
f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
)
warned = True
res.append(None)
remaining_reqs.append(req)
elif hsh in self.dbdict:
ob = self.dbdict[hsh]
assert ob is not None
res.append(ob)
else:
res.append(None)
remaining_reqs.append(req)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
while res[resptr] is not None:
resptr += 1
res[resptr] = r
# caching
hsh = hash_args(attr, req.args)
self.dbdict[hsh] = r
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
return CacheHook(self)
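For reference, a minimal sketch (not part of this commit) of how the wrapper above is used; the `lm` instance and the `requests` list are assumed to already exist:

```python
from lm_eval.api.model import CachingLM

# Wrap any lm_eval.api.model.LM instance; the db file (and its parent
# directory) is created on first use via SqliteDict.
cached_lm = CachingLM(lm, "lm_cache/my_model_rank0.db")  # path is illustrative

# Method calls are proxied through __getattr__: each request is looked up by
# hash_args("loglikelihood", req.args); only cache misses reach the wrapped
# lm, and their results are written back to the SQLite db.
results = cached_lm.loglikelihood(requests)
```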
...@@ -52,7 +52,6 @@ class TaskConfig(dict):
task: str = None
group: Union[str, list] = None
reference: str = None
dataset_path: str = None
dataset_name: str = None
...@@ -67,6 +66,8 @@ class TaskConfig(dict):
doc_to_target: Union[Callable, str] = None
use_prompt: str = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
num_fewshot: int = 0
batch_size: int = 1
...@@ -76,8 +77,6 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
...@@ -343,7 +342,7 @@ class Task(abc.ABC):
fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random()
)
# TODO: hardcoded for now: # of runs on each input to be 2.
# TODO: we should override this if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
...@@ -773,7 +772,7 @@ class ConfigurableTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", "{}".format(choice)),
arguments=("", " {}".format(choice)),
idx=i,
**kwargs,
)
......
...@@ -39,7 +39,7 @@ def simple_evaluate(
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
use_cache=None,
limit=None,
bootstrap_iters=100000,
check_integrity=False,
...@@ -64,8 +64,8 @@ def simple_evaluate(
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
Whether or not to cache
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
...@@ -99,6 +99,16 @@ def simple_evaluate(
assert isinstance(model, lm_eval.api.model.LM)
lm = model
if use_cache is not None:
print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank" + str(lm.rank) + ".db",
)
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
if check_integrity:
...@@ -127,7 +137,7 @@ def simple_evaluate(
if hasattr(lm, "batch_sizes")
else [],
"device": device,
"no_cache": no_cache,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
......
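A hedged sketch of calling the evaluator with the new argument (model and task names are placeholder assumptions; only `use_cache` itself is taken from this diff):

```python
from lm_eval.evaluator import simple_evaluate

results = simple_evaluate(
    model="hf",                     # assumed model name
    model_args="pretrained=gpt2",   # assumed model args
    tasks=["lambada_openai"],       # assumed task name
    use_cache="lm_cache/gpt2",      # responses cached per rank to lm_cache/gpt2_rank<rank>.db
)
```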
...@@ -88,6 +88,8 @@ class AnthropicLM(LM):
if not requests:
return []
requests = [req.args for req in requests]
res = []
for request in tqdm(requests):
inp = request[0]
...@@ -102,6 +104,9 @@ class AnthropicLM(LM):
stop=until,
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
return res
def _model_call(self, inps):
......
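Each backend diff in this commit follows the same pattern: once a result has been computed for a request, it is handed to `self.cache_hook.add_partial(...)` so that a wrapping CachingLM can persist it. A hedged sketch of that pattern in isolation (the `_generate` helper is hypothetical, standing in for the backend's real API call):

```python
def greedy_until(self, requests):
    # `requests` is a list of Instance objects; backends work with their .args
    requests = [req.args for req in requests]
    res = []
    for request in requests:
        context, gen_kwargs = request
        response = self._generate(context, gen_kwargs)  # hypothetical backend call
        res.append(response)
        # record (request args -> result) so repeated runs can be served from cache
        self.cache_hook.add_partial("greedy_until", request, response)
    return res
```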
...@@ -486,6 +486,8 @@ class HFLM(LM):
res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
...@@ -497,26 +499,28 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
for context, gen_kwargs in tqdm(
re_ord.get_reordered(), disable=(self.rank != 0)
):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [gen_kwargs]
until = [kwargs]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
...@@ -539,7 +543,7 @@ class HFLM(LM):
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
**kwargs,
)
cont_toks_list = cont[0].tolist()
...@@ -556,4 +560,6 @@ class HFLM(LM):
res.append(s)
self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
return re_ord.get_original(res)
...@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
...@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
try:
until = request_args["until"][
0
] # TODO: does this handle a list of stop seqs correctly?
except KeyError:
until = "<|endoftext|>"
response = oa_completion(
engine=self.engine,
prompt=inps,
...@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"]
until_ = args_.get("until", [])
for term in until_:
s = s.split(term)[0]
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s)
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
)
res.append(s)
......
...@@ -101,6 +101,10 @@ class TextSynthLM(LM):
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
self.cache_hook.add_partial(
"loglikelihood", (context, continuation), (logprob, is_greedy)
)
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
...@@ -141,6 +145,8 @@ class TextSynthLM(LM):
if "text" in resp:
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......
...@@ -39,7 +39,7 @@ def parse_args():
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_cache", type=str, default=None)
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
...@@ -85,7 +85,7 @@ def main():
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
use_cache=args.use_cache,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
......