Unverified Commit 0270256b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Allow nested tensors in predicted logits (#7542)

parent 60de910e
...@@ -48,6 +48,7 @@ from .trainer_utils import ( ...@@ -48,6 +48,7 @@ from .trainer_utils import (
distributed_broadcast_scalars, distributed_broadcast_scalars,
distributed_concat, distributed_concat,
nested_concat, nested_concat,
nested_detach,
nested_numpify, nested_numpify,
nested_xla_mesh_reduce, nested_xla_mesh_reduce,
set_seed, set_seed,
...@@ -1466,16 +1467,18 @@ class Trainer: ...@@ -1466,16 +1467,18 @@ class Trainer:
logits = outputs[:] logits = outputs[:]
if self.args.past_index >= 0: if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1] self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
# Remove the past from the logits.
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
if prediction_loss_only: if prediction_loss_only:
return (loss, None, None) return (loss, None, None)
logits = tuple(logit.detach() for logit in logits) logits = nested_detach(logits)
if len(logits) == 1: if len(logits) == 1:
logits = logits[0] logits = logits[0]
if has_labels: if has_labels:
labels = tuple(inputs.get(name).detach() for name in self.label_names) labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1: if len(labels) == 1:
labels = labels[0] labels = labels[0]
else: else:
......
...@@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0): ...@@ -154,6 +154,13 @@ def nested_concat(tensors, new_tensors, dim=0):
raise ImportError("Torch must be installed to use `nested_concat`") raise ImportError("Torch must be installed to use `nested_concat`")
def nested_deatch(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_numpify(tensors): def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)." "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)): if isinstance(tensors, (list, tuple)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment