Commit 3a51d909 authored by Cautiousss, committed by Frank Lee

fix format (#332)


Co-authored-by: 何晓昕 <cautious@r-205-106-25-172.comp.nus.edu.sg>
parent cbb6436f
@@ -7,7 +7,7 @@ from colossalai.utils import get_current_device
def send_tensor_meta(tensor, need_meta=True, next_rank=None):
    """Sends tensor meta information before sending a specific tensor.
    Since the recipient must know the shape of the tensor in p2p communications,
    meta information of the tensor should be sent before communications. This function
    synchronizes with :func:`recv_tensor_meta`.
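The docstring above describes a shape handshake: the sender transmits the tensor's metadata so the receiver can allocate a matching buffer before the tensor itself arrives. A minimal sketch of that idea, not ColossalAI's actual implementation, assuming an initialized torch.distributed process group; the helper name send_meta_sketch is hypothetical:

import torch
import torch.distributed as dist

def send_meta_sketch(tensor, next_rank):
    # Illustrative sketch, not the function in the diff.
    # Send the number of dimensions first, then the shape values,
    # so the peer knows how many integers to expect.
    ndim = torch.tensor([tensor.dim()], dtype=torch.long)
    shape = torch.tensor(tensor.size(), dtype=torch.long)
    dist.send(ndim, dst=next_rank)
    dist.send(shape, dst=next_rank)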
@@ -36,7 +36,7 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
def recv_tensor_meta(tensor_shape, prev_rank=None):
    """Receives tensor meta information before receiving a specific tensor.
    Since the recipient must know the shape of the tensor in p2p communications,
    meta information of the tensor should be received before communications. This function
    synchronizes with :func:`send_tensor_meta`.
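The matching receive side of the same sketch (again an illustration with a hypothetical name, recv_meta_sketch, not the function in the diff): receive the dimension count, then the shape, and return it so a correctly sized buffer can be allocated for the tensor that follows.

def recv_meta_sketch(prev_rank):
    # Illustrative sketch: receive the dimension count, then that
    # many shape entries, mirroring send_meta_sketch above.
    ndim = torch.empty(1, dtype=torch.long)
    dist.recv(ndim, src=prev_rank)
    shape = torch.empty(int(ndim.item()), dtype=torch.long)
    dist.recv(shape, src=prev_rank)
    return torch.Size(shape.tolist())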
@@ -104,6 +104,6 @@ def gather_split_1d_tensor(tensor):
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
-    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
+    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
    return gathered
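For context on the reformatted line: gather_split_1d_tensor reassembles a 1-D tensor that was split across the 1-D tensor-parallel group. The chunks list holds world_size views into one pre-allocated buffer, so all_gather writes each rank's shard directly into place and gathered ends up contiguous with no extra copy. A self-contained sketch of the same pattern, using the default process group instead of ColossalAI's gpc; the name gather_split_1d_sketch is hypothetical:

import torch
import torch.distributed as dist

def gather_split_1d_sketch(tensor):
    world_size = dist.get_world_size()
    numel = tensor.numel()
    gathered = torch.empty(numel * world_size, dtype=tensor.dtype,
                           device=tensor.device, requires_grad=False)
    # Each chunk is a view into `gathered`; all_gather fills the
    # buffer in rank order, one shard per rank.
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor)
    return gathered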