Unverified Commit 244e1b5b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix #7304 (#7305)

parent e4610881
...@@ -1334,9 +1334,9 @@ class Trainer: ...@@ -1334,9 +1334,9 @@ class Trainer:
elif is_torch_tpu_available(): elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None: if preds is not None:
preds = nested_xla_mesh_reduce("eval_preds", preds) preds = nested_xla_mesh_reduce(preds, "eval_preds")
if label_ids is not None: if label_ids is not None:
label_ids = nested_xla_mesh_reduce("eval_label_ids", label_ids, torch.cat) label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
if eval_losses is not None: if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist() eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment