Commit 3a51d909 authored by Cautiousss's avatar Cautiousss Committed by Frank Lee
Browse files

fix format (#332)


Co-authored-by: default avatar何晓昕 <cautious@r-205-106-25-172.comp.nus.edu.sg>
parent cbb6436f
...@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor): ...@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor):
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
requires_grad=False) requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)] chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
return gathered return gathered
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment