Commit ba864e09 authored by lintangsutawika's avatar lintangsutawika
Browse files

fix to add a case for when a user adds `max_length` to generation_kwargs

parent d88a566c
...@@ -757,11 +757,13 @@ class HFLM(LM): ...@@ -757,11 +757,13 @@ class HFLM(LM):
context_enc = context_enc.to(self.device) context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device) attn_masks = attn_masks.to(self.device)
if "max_length" not in kwargs:
kwargs["max_length"] = (context_enc.shape[1] + max_gen_toks,)
# perform batched generation # perform batched generation
cont = self._model_generate( cont = self._model_generate(
context=context_enc, context=context_enc,
attention_mask=attn_masks, attention_mask=attn_masks,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until, stop=primary_until,
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment