"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "aa4198a238f915e7ac04bc43d28ddbcb7fe690df"
Unverified Commit 9e6da0a7 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer: `distributed_concat`] ensure `all_gather`'s inputs are contiguous (#20951)

[trainer: distributed_concat] ensure all_gather's inputs are contiguous
parent 17292440
...@@ -189,7 +189,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> ...@@ -189,7 +189,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
try: try:
if isinstance(tensor, (tuple, list)): if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor) return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
tensor = atleast_1d(tensor) tensor = atleast_1d(tensor).contiguous()
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())] output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor) dist.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0) concat = torch.cat(output_tensors, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment