Unverified Commit 0270256b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Allow nested tensors in predicted logits (#7542)

parent 60de910e
......@@ -48,6 +48,7 @@ from .trainer_utils import (
distributed_broadcast_scalars,
distributed_concat,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
set_seed,
......@@ -1466,16 +1467,18 @@ class Trainer:
logits = outputs[:]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
# Remove the past from the logits.
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
if prediction_loss_only:
return (loss, None, None)
logits = tuple(logit.detach() for logit in logits)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.label_names)
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
......
......@@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0):
raise ImportError("Torch must be installed to use `nested_concat`")
def nested_deatch(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment