Unverified Commit 1b5ea747 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Place inputs on device when include_inputs_for_metrics is True (#18046)

parent 870ff9e1
......@@ -2804,7 +2804,7 @@ class Trainer:
# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available():
xm.mark_step()
......@@ -3352,7 +3352,7 @@ class Trainer:
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if loss is not None:
losses = loss.repeat(batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment