Commit 55bb4c06 authored by jglaser, committed by GitHub

Fix exception in prediction loop occurring for certain batch sizes (#12350)



* fix distributed_concat for scalar outputs

* Update README.md

* fixed typo (#12356)

* simplify fix with terser syntax
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Trigger CI
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: michal pitr <21157924+MichalPitr@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d4ce31e8
@@ -155,6 +155,7 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
         dist.all_gather(output_tensors, tensor)
+        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
         concat = torch.cat(output_tensors, dim=0)
         # truncate the dummy elements added by SequentialDistributedSampler
......
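
The added line can be tried without a distributed setup. The sketch below is illustrative only: the `gathered` list is a hand-built stand-in for what `dist.all_gather` would fill in when each process contributes a scalar (0-dim) tensor, and `promote_scalars` is a hypothetical helper name that wraps the same list comprehension as the diff. It is not code from the repository.

```python
import torch


def promote_scalars(tensors):
    """Give every 0-dim tensor a leading axis of length 1; leave others as-is.

    Hypothetical helper name; the body mirrors the line added in the diff.
    """
    return [t if len(t.shape) > 0 else t[None] for t in tensors]


# Assumed stand-in for the all_gather output across 3 processes,
# each contributing a scalar tensor (for example a per-batch loss).
gathered = [torch.tensor(0.5), torch.tensor(1.5), torch.tensor(2.5)]

try:
    torch.cat(gathered, dim=0)
    cat_failed = False
except RuntimeError:
    # torch.cat refuses zero-dimensional tensors.
    cat_failed = True

# After promotion each tensor has shape (1,), so concatenation succeeds.
concat = torch.cat(promote_scalars(gathered), dim=0)
```

Indexing with `None` adds a new leading axis, so a scalar becomes a one-element vector while tensors that already have at least one dimension pass through unchanged.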