Unverified commit 091693b4, authored by Patrick von Platen, committed by GitHub
Browse files

[Seq2SeqTrainer] Remove model input name hack (#14802)

* [Seq2SeqTrainer] Remove model input name hack

* Update src/transformers/trainer_seq2seq.py

* make style

* finish
parent 84ea427f
...@@ -159,12 +159,8 @@ class Seq2SeqTrainer(Trainer): ...@@ -159,12 +159,8 @@ class Seq2SeqTrainer(Trainer):
"synced_gpus": True if is_deepspeed_zero3_enabled() else False, "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
} }
if self.tokenizer is not None: model_input_names = self.tokenizer.model_input_names if self.tokenizer is not None else ["input_ids"]
generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names} generation_inputs = {k: v for k, v in inputs.items() if k in model_input_names}
# very ugly hack to make it work
generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
else:
generation_inputs = {"input_ids": inputs["input_ids"]}
generated_tokens = self.model.generate( generated_tokens = self.model.generate(
**generation_inputs, **generation_inputs,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment