Unverified Commit 1b5ea747 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Place inputs on device when include_inputs_for_metrics is True (#18046)

parent 870ff9e1
...@@ -2804,7 +2804,7 @@ class Trainer: ...@@ -2804,7 +2804,7 @@ class Trainer:
# Prediction step # Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available(): if is_torch_tpu_available():
xm.mark_step() xm.mark_step()
...@@ -3352,7 +3352,7 @@ class Trainer: ...@@ -3352,7 +3352,7 @@ class Trainer:
for step, inputs in enumerate(dataloader): for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if loss is not None: if loss is not None:
losses = loss.repeat(batch_size) losses = loss.repeat(batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment