OpenDAS / ColossalAI · Commits

Commit 8004c8e9 (unverified)
Authored Apr 25, 2022 by Frank Lee; committed by GitHub on Apr 25, 2022
Parent: 8af5f742

[doc] improved docstring in the communication module (#863)

Changes: 4 changed files with 49 additions and 39 deletions (+49 −39)
colossalai/communication/collective.py   +1 −1
colossalai/communication/p2p.py          +32 −25
colossalai/communication/ring.py         +3 −3
colossalai/communication/utils.py        +13 −10
colossalai/communication/collective.py

@@ -208,7 +208,7 @@ def reduce(tensor: Tensor,
     return out


-def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
+def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
     r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
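For context, a minimal usage sketch of the patched helper, assuming an initialized torch.distributed process group (the two-rank object list below is illustrative):

```python
import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list

# Each rank pre-allocates a single-slot output list, mirroring the
# torch.distributed.scatter_object_list calling convention.
output = [None]
if dist.get_rank() == 0:
    # The source rank supplies one picklable object per rank in the group.
    inputs = [{"stage": 0}, {"stage": 1}]
else:
    inputs = None
scatter_object_list(output, inputs, src=0)
received = output[0]  # this rank's scattered object
```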
colossalai/communication/p2p.py

@@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
         chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

     Returns:
-        Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
+        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
     """
     if chunk_tensor:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
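The chunking branch shown in context flattens the shape into a total element count; a quick standalone illustration of that reduce call:

```python
import operator
from functools import reduce

tensor_shape = (4, 8, 16)
# When chunking, the tensor travels as one flat buffer of numel elements.
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
assert tensor_chunk_shape == 512
```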
@@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
     return tensor_chunk_shape, chunk_tensor


-def _communicate(tensor_send_next=None,
-                 tensor_send_prev=None,
-                 recv_prev=False,
-                 recv_next=False,
-                 recv_prev_shape=None,
-                 recv_next_shape=None,
-                 prev_rank=None,
-                 next_rank=None,
-                 dtype=None,
-                 scatter_gather_tensors=False):
+def _communicate(tensor_send_next: torch.Tensor = None,
+                 tensor_send_prev: torch.Tensor = None,
+                 recv_prev: bool = False,
+                 recv_next: bool = False,
+                 recv_prev_shape: TensorShape = None,
+                 recv_next_shape: TensorShape = None,
+                 prev_rank: int = None,
+                 next_rank: int = None,
+                 dtype: torch.dtype = None,
+                 scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
     communication methods that are used in pipeline schedule.
     Takes the following arguments:
-        tensor_send_next: tensor to send to next rank (no tensor sent if
+        tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
             set to None).
-        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+        tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
             set to None).
-        recv_prev: boolean for whether tensor should be received from
+        recv_prev (bool): boolean for whether tensor should be received from
             previous rank.
-        recv_next: boolean for whether tensor should be received from
+        recv_next (bool): boolean for whether tensor should be received from
             next rank.
+        recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defaults to None.
+        recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defaults to None.
+        prev_rank (int): the rank of the previous pipeline stage, defaults to None.
+        next_rank (int): the rank of the next pipeline stage, defaults to None.
+        dtype (torch.dtype): data type of intermediate buffers, defaults to None.
+        scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False.

     Returns:
-        (tensor_recv_prev, tensor_recv_next)
+        Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
     """

     # Create placeholder tensors for receive in forward and backward directions
@@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

     Args:
@@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
     return input_tensor


-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

     Args:
@@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
     return output_tensor_grad


-def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the input tensor to the next stage in pipeline.

     Args:
@@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


-def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the gradient tensor to the previous stage in pipeline.

     Args:
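Together, recv_forward / send_forward / recv_backward / send_backward are the building blocks a pipeline schedule composes per microbatch. A minimal sketch of one intermediate stage's step, assuming an initialized pipeline context (the model call and shape arguments are illustrative placeholders):

```python
import torch
from colossalai.communication.p2p import (recv_forward, send_forward,
                                          recv_backward, send_backward)

def intermediate_stage_step(model, input_shape, output_grad_shape):
    # Forward pass: pull activations from the previous stage, push ours on.
    input_tensor = recv_forward(input_shape, dtype=torch.float)
    input_tensor.requires_grad_()  # so a gradient can be returned later
    output_tensor = model(input_tensor)
    send_forward(output_tensor)

    # Backward pass: pull the gradient from the next stage, push ours back.
    output_grad = recv_backward(output_grad_shape, dtype=torch.float)
    torch.autograd.backward(output_tensor, grad_tensors=output_grad)
    send_backward(input_tensor.grad)
```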
@@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor,
                                recv_next=True,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
@@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad,
                                recv_prev=True,
                                prev_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
@@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor,
                               prev_rank=None,
                               next_rank=None,
                               dtype=torch.float,
-                              scatter_gather_tensors=False):
+                              scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
@@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad,
                                 prev_rank=None,
                                 next_rank=None,
                                 dtype=torch.float,
-                                scatter_gather_tensors=False):
+                                scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the gradient tensor from the
     next member in pipeline as the input of this stage.
@@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 prev_rank=None,
                                                 next_rank=None,
                                                 dtype=torch.float,
-                                                scatter_gather_tensors=False):
+                                                scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
     """Batched communication operation. Sends the input tensor to the next stage in pipeline and
     the gradient tensor to the previous stage, while receives the input gradient tensor from the
     next stage and the input tensor from the previous stage.
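In the steady 1F1B phase a stage can fuse its send and its receive into one of these batched calls instead of two separate ones. A hedged sketch, assuming the expected receive shape is the second positional parameter (shapes are placeholders):

```python
# Fused alternative to send_forward(...) followed by recv_backward(...):
output_grad = send_forward_recv_backward(output_tensor,
                                         output_grad_shape,
                                         dtype=torch.float)

# Fused alternative to send_backward(...) followed by recv_forward(...):
input_tensor = send_backward_recv_forward(input_tensor_grad,
                                          input_shape,
                                          dtype=torch.float)
```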
colossalai/communication/ring.py

@@ -8,13 +8,13 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device, synchronize


-def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
+def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
     """Sends a tensor to the next member and receives a tensor from the previous member.
     This function returns the received tensor from the previous member.

     Args:
-        tensor_send_next: Tensor sent to next member
-        parallel_mode: Parallel group mode used in this communication
+        tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
+        parallel_mode (ParallelMode): Parallel group mode used in this communication

     Returns:
         :class:`torch.Tensor`: The tensor received from the previous.
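A sketch of the ring pattern this function enables, assuming an initialized ColossalAI global context (the parallel mode and group size below are illustrative choices, not mandated by the commit):

```python
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.communication.ring import ring_forward
from colossalai.utils import get_current_device

# Rotate a local block around the ring: after ring_size - 1 steps every
# rank has seen every other rank's block exactly once.
block = torch.randn(4, 8, device=get_current_device())
ring_size = gpc.get_world_size(ParallelMode.SEQUENCE)  # illustrative mode
for _ in range(ring_size - 1):
    block = ring_forward(block, ParallelMode.SEQUENCE)
    # ... consume `block`, e.g. accumulate a partial result ...
```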
colossalai/communication/utils.py

@@ -4,16 +4,19 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from typing import Union, List, Tuple
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]


-def send_tensor_meta(tensor, need_meta=True, next_rank=None):
+def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
     """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.

     Args:
-        tensor (torch.Tensor): Tensor to be sent.
+        tensor (:class:`torch.Tensor`): Tensor to be sent.
         need_meta (bool, optional): If False, meta information won't be sent.
         next_rank (int): The rank of the next member in pipeline parallel group.
@@ -34,14 +37,14 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
     return False


-def recv_tensor_meta(tensor_shape, prev_rank=None):
+def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
     """Receives tensor meta information before receiving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be received before communications. This function
     synchronizes with :func:`send_tensor_meta`.

     Args:
-        tensor_shape (torch.Size): The shape of the tensor to be received.
+        tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
         prev_rank (int): The rank of the source of the tensor.

     Returns:
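These two functions form a shape handshake around each p2p transfer: the sender publishes the shape once, and the receiver blocks until it knows what to allocate. A minimal sketch of the pairing (the tensor and the surrounding loop are illustrative):

```python
# Sender side: transmit the meta only while it is still needed; the function
# returns False once sent, so later iterations skip the handshake.
need_meta = True
need_meta = send_tensor_meta(output_tensor, need_meta=need_meta)

# Receiver side: passing None means "shape unknown, receive the meta first";
# an already-known shape is returned unchanged as a torch.Size.
input_shape = recv_tensor_meta(None)
```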
@@ -63,15 +66,15 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
     return tensor_shape


-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
     """Break a tensor into equal 1D chunks.

     Args:
-        tensor (torch.Tensor): Tensor to be split before communication.
+        tensor (:class:`torch.Tensor`): Tensor to be split before communication.
         new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.

     Returns:
-        :class:`torch.Size`: The split tensor
+        :class:`torch.Tensor`: The split tensor
     """
     partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
     start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -84,13 +87,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     return data


-def gather_split_1d_tensor(tensor):
+def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """Opposite of above function, gather values from model parallel ranks.

     Args:
-        tensor (torch.Tensor): Tensor to be gathered after communication.
+        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.

     Returns:
-        :class:`torch.Size`: The gathered tensor.
+        :class:`torch.Tensor`: The gathered tensor.
     """
     world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
     numel = torch.numel(tensor)
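These helpers back the scatter_gather_tensors path in p2p.py: the sender ships only its 1D slice and the receiver reassembles the flat buffer. A round-trip sketch, assuming it runs inside an initialized 1D parallel group (the shape is illustrative and must divide evenly by the group size):

```python
import torch
from colossalai.communication.utils import (split_tensor_into_1d_equal_chunks,
                                            gather_split_1d_tensor)

activation = torch.randn(4, 1024)
chunk = split_tensor_into_1d_equal_chunks(activation, new_buffer=True)
# ... `chunk` is what actually crosses the wire between stages ...
gathered = gather_split_1d_tensor(chunk)      # flat buffer with the full numel
restored = gathered.view(activation.shape)    # reshape on the receiving side
```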