"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "27174bd4fe687b06f149c74cf7f1ebc5c9f03082"
Unverified commit 5896b3ec, authored by Antoni Baum, committed by GitHub

Fix `distributed_concat` with scalar tensor (#16963)

* Fix `distributed_concat` with scalar tensor

* Update trainer_pt_utils.py
parent 084c38c5
@@ -159,8 +159,9 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
+        if len(tensor.shape) <= 0:
+            tensor = tensor[None]
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
-        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
         dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
...
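For context, a minimal sketch of why the patch helps: the old code unsqueezed only the gather buffers, so a 0-dim input (e.g. a scalar loss) no longer matched them in shape, and `torch.cat` cannot concatenate 0-dim tensors in any case. The patched logic unsqueezes the input itself before cloning, so the buffers and the gathered tensor are both 1-D. The snippet below replays the patched logic on a single-process CPU group; the `gloo` backend, address/port settings, and world size of 1 are illustrative assumptions, not part of the patch.

```python
# Minimal sketch of the patched distributed_concat logic, assuming a
# single-process "gloo" process group (illustrative setup, not from the PR).
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

tensor = torch.tensor(3.14)  # 0-dim (scalar) tensor, e.g. a loss value

# Patched logic: unsqueeze the *input* before building the gather buffers,
# so the buffers and the gathered tensor have matching 1-D shapes.
if len(tensor.shape) <= 0:
    tensor = tensor[None]
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)  # works: every piece is 1-D
print(concat)  # tensor([3.1400])

dist.destroy_process_group()
```

With the pre-patch version, the unsqueeze happened after cloning, leaving the input 0-dim while the buffers were 1-D; moving it before the clone keeps everything consistent with a single early normalization.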