Commit 197a2c89 authored by Zangwei Zheng's avatar Zangwei Zheng Committed by Frank Lee
Browse files

[NFC] polish colossalai/communication/collective.py (#1262)

parent f1cafcc7
...@@ -10,10 +10,7 @@ from colossalai.context import ParallelMode ...@@ -10,10 +10,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
def all_gather(tensor: Tensor, def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
dim: int,
parallel_mode: ParallelMode,
async_op: bool = False) -> Tensor:
r"""Gathers all tensors from the parallel group and concatenates them in a r"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension. specific dimension.
...@@ -163,11 +160,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b ...@@ -163,11 +160,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
return out return out
def reduce(tensor: Tensor, def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
dst: int,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False):
r"""Reduce tensors across whole parallel group. Only the process with r"""Reduce tensors across whole parallel group. Only the process with
rank ``dst`` is going to receive the final result. rank ``dst`` is going to receive the final result.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment