Unverified Commit a7d21318 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Use main_input_name for include_inputs_for_metrics (#24993)

parent a6484c89
......@@ -3121,7 +3121,8 @@ class Trainer:
# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available():
xm.mark_step()
......@@ -3674,7 +3675,8 @@ class Trainer:
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
if loss is not None:
losses = loss.repeat(batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment