Commit 4e0410e9 authored by Sylvain Gugger

Fix in gather for SM distributed

parent 367c2ef5
@@ -162,8 +162,8 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
-        dist.all_gather(output_tensors, tensor)
         output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
+        dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
         # truncate the dummy elements added by SequentialDistributedSampler
......
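For context, here is a minimal, self-contained sketch (not part of the commit) of why the gathered tensors need at least one dimension: a scalar value such as a loss arrives as a 0-dim tensor, `torch.cat` refuses to concatenate 0-dim tensors, and the `t if len(t.shape) > 0 else t[None]` expression from the patch promotes each one to shape `(1,)`. The `promote_to_1d` helper name and the use of plain `torch` without any distributed backend are illustrative assumptions, not code from the repository.

```python
import torch


def promote_to_1d(tensors):
    # Illustrative helper (not from the commit): mirrors the patch's
    # `t if len(t.shape) > 0 else t[None]` to give 0-dim tensors a batch axis.
    return [t if len(t.shape) > 0 else t[None] for t in tensors]


# A scalar loss gathered from, say, two processes arrives as 0-dim tensors.
gathered = [torch.tensor(0.25), torch.tensor(0.31)]

# torch.cat cannot concatenate zero-dimensional tensors.
try:
    torch.cat(gathered, dim=0)
except RuntimeError as err:
    print(f"cat on 0-dim tensors fails: {err}")

# After promoting each tensor to shape (1,), concatenation works and yields
# a 1-D tensor holding one value per process.
concat = torch.cat(promote_to_1d(gathered), dim=0)
print(concat)  # tensor([0.2500, 0.3100])
```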