Unverified Commit fb30b7c7 authored by Fengzhe Zhou's avatar Fengzhe Zhou Committed by GitHub
Browse files

[Fix] Fix gen inferencer (#615)

parent 721a45c6
...@@ -130,8 +130,8 @@ class GenInferencer(BaseInferencer): ...@@ -130,8 +130,8 @@ class GenInferencer(BaseInferencer):
entry, max_out_len=self.max_out_len) entry, max_out_len=self.max_out_len)
generated = results generated = results
num_return_sequences = self.model.get('generation_kwargs', {}).get( num_return_sequences = getattr(self.model, 'generation_kwargs',
'num_return_sequences', 1) {}).get('num_return_sequences', 1)
# 5-3. Save current output # 5-3. Save current output
for prompt, prediction, gold in zip( for prompt, prediction, gold in zip(
parsed_entries, batched(generated, num_return_sequences), parsed_entries, batched(generated, num_return_sequences),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment