Unverified Commit 269c3d14 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix gather for TPU (#13813)

parent 7db2a79b
...@@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name): ...@@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name):
if isinstance(tensors, (list, tuple)): if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
if tensors.ndim == 0:
tensors = tensors[None]
return xm.mesh_reduce(name, tensors, torch.cat) return xm.mesh_reduce(name, tensors, torch.cat)
else: else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`") raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment