Unverified Commit 39f1dff5 authored by TranSirius's avatar TranSirius Committed by GitHub
Browse files

Fix a Bug, trainer_seq2seq.py, in the else branch at Line 172,...

Fix a Bug, trainer_seq2seq.py, in the else branch at Line 172, generation_inputs should be a dict (#14546)

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()

* fix bug, trainer_seq2seq.py, Line 172, generation_inputs must be a dict before feeding into self.model.generation()
parent 2171695c
...@@ -169,7 +169,7 @@ class Seq2SeqTrainer(Trainer): ...@@ -169,7 +169,7 @@ class Seq2SeqTrainer(Trainer):
# very ugly hack to make it work # very ugly hack to make it work
generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0]) generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
else: else:
generation_inputs = inputs["input_ids"] generation_inputs = {"input_ids": inputs["input_ids"]}
generated_tokens = self.model.generate( generated_tokens = self.model.generate(
**generation_inputs, **generation_inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment