Unverified Commit deecaa1a authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #723 from EleutherAI/fix_max_length

[Refactor] Fix Max Length arg
parents d88a566c 53105bba
......@@ -990,7 +990,6 @@ class ConfigurableTask(Task):
choices = self.doc_to_choice(doc)
gold = choices[gold]
print(self._metric_fn_list)
for key, result in zip(self._metric_fn_list.keys(), results):
if self.multiple_target:
# in the case where we have multiple targets,
......
......@@ -757,11 +757,13 @@ class HFLM(LM):
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
if "max_length" not in kwargs:
kwargs["max_length"] = (context_enc.shape[1] + max_gen_toks,)
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment