Commit e3960fa0 authored by haileyschoelkopf

add cache hooks to HF

parent a68c3fa4
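The `cache_hook.add_partial` calls added below feed the harness's request-level cache, so per-request results survive interruption and re-runs. For orientation, here is a minimal sketch of such a hook; only the `add_partial(attr, request, result)` call shape comes from this diff, and the storage details are illustrative:

```python
import hashlib
import json


class CacheHook:
    """Illustrative sketch of a request-level result cache (not this repo's code)."""

    def __init__(self, db=None):
        # Any dict-like store works; a real harness would use an on-disk store.
        self.db = db

    @staticmethod
    def _hash_args(attr, args):
        # Key results by method name plus JSON-serialized request arguments.
        payload = json.dumps([attr, args], sort_keys=True, default=str)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def add_partial(self, attr, req, res):
        # No-op when the model is not wrapped in a caching layer.
        if self.db is None:
            return
        self.db[self._hash_args(attr, req)] = res
```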
@@ -486,6 +486,8 @@ class HFLM(LM):
             res.append(answer)
 
+            self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+
         return re_ord.get_original(res)
 
     def greedy_until(self, requests):
 
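With a hook like the sketch above, each scored request is persisted the moment its answer is computed, so a crashed run can resume from the cache instead of re-scoring everything. A hypothetical usage mirroring the hunk above:

```python
# `cache_key` and `answer` stand in for whatever the loglikelihood loop
# computed for one request (the names are from the diff; the values here
# are made up for illustration).
hook = CacheHook(db={})
cache_key = ("Question: 2+2=", " 4")
answer = (-0.31, True)  # (log-probability, is_greedy)
hook.add_partial("loglikelihood", cache_key, answer)
assert hook.db[CacheHook._hash_args("loglikelihood", cache_key)] == answer
```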
@@ -497,26 +499,28 @@ class HFLM(LM):
 
         re_ord = utils.Reorderer([req.args for req in requests], _collate)
 
-        for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+        for context, gen_kwargs in tqdm(
+            re_ord.get_reordered(), disable=(self.rank != 0)
+        ):
             until = None
             if isinstance(gen_kwargs, dict):
-                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in gen_kwargs.keys():
-                    until = gen_kwargs.pop("until")
+                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                if "until" in kwargs.keys():
+                    until = kwargs.pop("until")
                     if isinstance(until, str):
-                        until = [gen_kwargs]
+                        until = [kwargs]
                     elif not isinstance(until, list):
                         raise ValueError(
-                            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
+                            f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
                         )
             else:
                 raise ValueError(
-                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
+                    f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
                 )
             if not until:
                 until = [self.tok_decode(self.eot_token_id)]
-            if "max_gen_toks" in gen_kwargs.keys():
-                max_gen_toks = gen_kwargs.pop("max_gen_toks")
+            if "max_gen_toks" in kwargs.keys():
+                max_gen_toks = kwargs.pop("max_gen_toks")
             else:
                 max_gen_toks = self.max_gen_toks
             # first stop sequence is used to halt generation upon encountering
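Two things happen in this hunk: the progress bar is silenced on non-zero ranks (`disable=(self.rank != 0)`) so multi-process runs print a single bar, and the per-request kwargs are deep-copied into a local `kwargs` before keys are popped, so a request object reused across repeats is never mutated. Note that `until = [kwargs]` on the new side wraps the kwargs dict rather than the stop string; `until = [until]` is presumably the intent, and the sketch below assumes that reading. The normalization logic, extracted as a standalone helper (a sketch, not code from this commit):

```python
import copy


def normalize_gen_kwargs(gen_kwargs, eot_text, default_max_gen_toks=256):
    """Split `until` and `max_gen_toks` out of one request's generation kwargs."""
    if not isinstance(gen_kwargs, dict):
        raise ValueError(
            f"Expected `generation_kwargs` to be of type `dict` but got {gen_kwargs}"
        )
    # Deep-copy first: popping keys must not corrupt a request reused for repeats > 1.
    kwargs = copy.deepcopy(gen_kwargs)
    until = kwargs.pop("until", None)
    if isinstance(until, str):
        until = [until]  # assumes the intent behind the diff's `until = [kwargs]`
    elif until is not None and not isinstance(until, list):
        raise ValueError(
            f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
        )
    if not until:
        until = [eot_text]  # fall back to the model's EOT text as the stop sequence
    max_gen_toks = kwargs.pop("max_gen_toks", default_max_gen_toks)
    return until, max_gen_toks, kwargs
```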
@@ -539,7 +543,7 @@ class HFLM(LM):
                 context=context_enc,
                 max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
-                **gen_kwargs,
+                **kwargs,
             )
 
             cont_toks_list = cont[0].tolist()
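`_model_generate` itself is not shown in this diff; one plausible way it can honor the `stop=primary_until` argument with Hugging Face `transformers` is a custom `StoppingCriteria` that decodes the continuation so far and checks for the stop string (a sketch under that assumption):

```python
from transformers import StoppingCriteria, StoppingCriteriaList


class StopStringCriteria(StoppingCriteria):
    """Halt generation once the decoded continuation contains a stop string."""

    def __init__(self, tokenizer, stop_string, prompt_len):
        self.tokenizer = tokenizer
        self.stop_string = stop_string
        self.prompt_len = prompt_len  # tokens belonging to the prompt, not the continuation

    def __call__(self, input_ids, scores, **kwargs):
        continuation = self.tokenizer.decode(input_ids[0, self.prompt_len:])
        return self.stop_string in continuation


# Hypothetical use inside a `_model_generate`-like function:
#   criteria = StoppingCriteriaList(
#       [StopStringCriteria(tokenizer, primary_until, context_enc.shape[1])]
#   )
#   model.generate(context_enc, max_length=..., stopping_criteria=criteria, **kwargs)
```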
@@ -556,4 +560,6 @@ class HFLM(LM):
 
             res.append(s)
 
+            self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
+
         return re_ord.get_original(res)
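Symmetrically, the `greedy_until` result is cached under the original `(context, gen_kwargs)` pair rather than the mutated local `kwargs`, which keeps cache keys stable across runs. On the read side, a caching wrapper would consult the store before generating at all; a hypothetical sketch reusing the `CacheHook` from above:

```python
def cached_greedy_until(hook, generate_fn, context, gen_kwargs):
    # Check the store first; on a hit, skip generation entirely.
    key = CacheHook._hash_args("greedy_until", (context, gen_kwargs))
    if hook.db is not None and key in hook.db:
        return hook.db[key]
    s = generate_fn(context, gen_kwargs)
    hook.add_partial("greedy_until", (context, gen_kwargs), s)
    return s
```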