Unverified Commit 2c42b230, authored by アマデウス, committed by GitHub

updated collective ops api (#1054)

parent 51b9a496
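
In summary, the commit drops the explicit `on_cpu` flag from `all_gather`, `reduce_scatter`, `all_reduce`, `broadcast`, and `reduce`: each op now selects the Gloo (CPU) group when the input tensor lives on the CPU, and the default device group otherwise, by checking `tensor.device.type`. A before/after sketch of a call site (the caller code here is hypothetical; only the signatures come from this diff):

```python
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode

def sum_across_ranks(x):
    # Old API: the Gloo backend had to be requested explicitly:
    #   out = all_reduce(x, ParallelMode.GLOBAL, on_cpu=True)
    # New API: a CPU input tensor is routed through the Gloo group
    # automatically; a CUDA tensor uses the default (e.g. NCCL) group.
    return all_reduce(x, ParallelMode.GLOBAL)
```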
@@ -13,7 +13,6 @@ from colossalai.core import global_context as gpc
 def all_gather(tensor: Tensor,
                dim: int,
                parallel_mode: ParallelMode,
-               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
@@ -26,7 +25,6 @@ def all_gather(tensor: Tensor,
         tensor (:class:`torch.Tensor`): Tensor to be gathered.
         dim (int): The dimension concatenating in.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.
 
     Returns:
@@ -43,7 +41,7 @@ def all_gather(tensor: Tensor,
         shape[0] *= depth
         out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
         temp = list(torch.chunk(out, depth, dim=0))
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.all_gather(tensor_list=temp,
                                tensor=tensor.transpose(0, dim).contiguous(),
                                group=group,
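
The gather-along-`dim` trick above is worth unpacking: the input is transposed so that `dim` becomes the leading axis, the preallocated output is chunked into world-size views, and `dist.all_gather` writes each rank's contribution directly into those views. A minimal standalone sketch of the same idea in plain `torch.distributed` (assuming an already-initialized process group; the final transpose mirrors the function's documented return shape):

```python
import torch
import torch.distributed as dist

def gather_along_dim(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    depth = dist.get_world_size(group)
    t = tensor.transpose(0, dim).contiguous()      # move `dim` to the front
    shape = list(t.shape)
    shape[0] *= depth                              # room for every rank's slice
    out = torch.empty(shape, dtype=t.dtype, device=t.device)
    chunks = list(torch.chunk(out, depth, dim=0))  # views into `out`
    dist.all_gather(tensor_list=chunks, tensor=t, group=group)
    return out.transpose(0, dim).contiguous()      # restore the original layout
```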
@@ -59,7 +57,6 @@ def reduce_scatter(tensor: Tensor,
                    dim: int,
                    parallel_mode: ParallelMode,
                    op: ReduceOp = ReduceOp.SUM,
-                   on_cpu: bool = False,
                    async_op: bool = False) -> Tensor:
     r"""Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
@@ -76,7 +73,6 @@ def reduce_scatter(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.
 
     Returns:
@@ -90,7 +86,7 @@ def reduce_scatter(tensor: Tensor,
     else:
         temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
         out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
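
Conversely, `reduce_scatter` splits the input into world-size contiguous chunks along `dim`, sums corresponding chunks across ranks, and leaves rank `i` holding chunk `i`. A standalone sketch of that semantics in plain `torch.distributed` (assumes an initialized process group):

```python
import torch
import torch.distributed as dist

def reduce_scatter_along_dim(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    depth = dist.get_world_size(group)
    # Each chunk must be contiguous before it is handed to the collective.
    chunks = [c.contiguous() for c in torch.chunk(tensor, depth, dim=dim)]
    out = torch.empty_like(chunks[0])
    # Rank i receives the elementwise sum over all ranks of chunks[i].
    dist.reduce_scatter(output=out, input_list=chunks, op=dist.ReduceOp.SUM, group=group)
    return out
```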
@@ -101,7 +97,6 @@ def reduce_scatter(tensor: Tensor,
 def all_reduce(tensor: Tensor,
                parallel_mode: ParallelMode,
                op: ReduceOp = ReduceOp.SUM,
-               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.
 
@@ -116,7 +111,6 @@ def all_reduce(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.
 
     Returns:
@@ -129,7 +123,7 @@ def all_reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.all_reduce(out, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
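
All five ops share the same async convention, visible in the tail of each hunk: with `async_op=True` they return an `(output, work)` pair, and the output is only valid after the handle completes. A hedged usage sketch (the caller code is hypothetical):

```python
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode

def overlapped_sum(x):
    out, work = all_reduce(x, ParallelMode.GLOBAL, async_op=True)
    # ... independent computation can overlap with the communication ...
    work.wait()   # `out` holds the reduced result only after this point
    return out
```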
@@ -137,7 +131,7 @@ def all_reduce(tensor: Tensor,
         return out
 
 
-def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
+def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
     r"""Broadcast tensors to whole parallel group. Tensor must have the same
     number of elements in all processes participating in the collective.
@@ -149,7 +143,6 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
         tensor (:class:`torch.Tensor`): Tensor to be broadcast.
         src (int): Source rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.
 
     Returns:
@@ -162,7 +155,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.broadcast(out, src=src, group=group, async_op=async_op)
     if async_op:
         return out, work
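
One subtlety in the broadcast path: `tensor.contiguous()` returns the same storage when the input is already contiguous, so on non-source ranks the caller's tensor is typically overwritten in place. A standalone sketch of the same semantics in plain `torch.distributed` (assumes an initialized process group):

```python
import torch
import torch.distributed as dist

def broadcast_from(tensor: torch.Tensor, src: int, group=None) -> torch.Tensor:
    out = tensor.contiguous()   # same storage if `tensor` is already contiguous
    # Rank `src` sends `out`; every other rank overwrites its `out` with it.
    dist.broadcast(out, src=src, group=group)
    return out
```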
@@ -174,7 +167,6 @@ def reduce(tensor: Tensor,
            dst: int,
            parallel_mode: ParallelMode,
            op: ReduceOp = ReduceOp.SUM,
-           on_cpu: bool = False,
            async_op: bool = False):
     r"""Reduce tensors across whole parallel group. Only the process with
     rank ``dst`` is going to receive the final result.
@@ -187,7 +179,6 @@ def reduce(tensor: Tensor,
         tensor (:class:`torch.Tensor`): Tensor to be reduced.
         dst (int): Destination rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
-        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.
 
     Returns:
@@ -200,7 +191,7 @@ def reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
         work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
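
Unlike `all_reduce`, `reduce` delivers the final result only on rank `dst`; on other ranks the returned buffer should not be relied upon. A hedged caller-side sketch (the helper below is hypothetical):

```python
from colossalai.communication import reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

def sum_to_rank0(x):
    out = reduce(x, dst=0, parallel_mode=ParallelMode.GLOBAL)
    if gpc.get_local_rank(ParallelMode.GLOBAL) == 0:
        return out    # the reduced value is meaningful only on rank 0
    return None
```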