"src/fastertransformer/models/llama/LlamaDecoder.cc" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Commit e3960fa0 authored by haileyschoelkopf
Browse files

add cache hooks to HF

parent a68c3fa4
......@@ -486,6 +486,8 @@ class HFLM(LM):
res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
......@@ -497,26 +499,28 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
for context, gen_kwargs in tqdm(
re_ord.get_reordered(), disable=(self.rank != 0)
):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [gen_kwargs]
until = [kwargs]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
f"Expected `generation_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
f"Expected `generation_kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
......@@ -539,7 +543,7 @@ class HFLM(LM):
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
**kwargs,
)
cont_toks_list = cont[0].tolist()
......@@ -556,4 +560,6 @@ class HFLM(LM):
res.append(s)
self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
return re_ord.get_original(res)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment