Unverified Commit 9e6da0a7 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer: `distributed_concat`] ensure `all_gather`'s inputs are contiguous (#20951)

[trainer: distributed_concat] ensure all_gather's input are contiguous
parent 17292440
......@@ -189,7 +189,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
tensor = atleast_1d(tensor)
tensor = atleast_1d(tensor).contiguous()
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment