OpenDAS / ColossalAI

Commit d3d5bedc
fix format (#607)
Authored Apr 01, 2022 by xyupeng; committed Apr 06, 2022 by binmakeswell.
Parent: f2d2a159
Showing 1 changed file with 13 additions and 19 deletions.

colossalai/communication/p2p.py (+13, -19)
@@ -12,7 +12,6 @@ from functools import reduce
import operator
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor

TensorShape = Union[torch.Size, List[int], Tuple[int]]
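For context, TensorShape is only a type alias covering the three common spellings of a tensor shape. A minimal sketch of how such an alias is typically consumed (normalize_shape below is a hypothetical helper, not part of p2p.py):

from typing import List, Tuple, Union

import torch

TensorShape = Union[torch.Size, List[int], Tuple[int]]

def normalize_shape(shape: TensorShape) -> torch.Size:
    # torch.Size subclasses tuple, so one constructor call covers lists and tuples alike.
    return shape if isinstance(shape, torch.Size) else torch.Size(shape)

assert normalize_shape([2, 3]) == torch.Size([2, 3])
assert normalize_shape((4,)) == torch.Size([4])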
@@ -88,13 +87,11 @@ def _communicate(tensor_send_next=None,
    if tensor_send_prev is not None or recv_prev:
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
    if tensor_send_next is not None or recv_next:
        if next_rank is None:
            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
    if tensor_send_prev is not None:
        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
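The two lookups above resolve a stage's neighbors in the pipeline when the caller does not pass ranks explicitly. A toy sketch of the neighbor-rank arithmetic such helpers usually reduce to, assuming world_size consecutive pipeline ranks arranged in a ring (the function names here are illustrative, not ColossalAI's API):

def get_prev_rank(rank: int, world_size: int) -> int:
    # The previous pipeline stage; stage 0 wraps around to the last stage.
    return (rank - 1) % world_size

def get_next_rank(rank: int, world_size: int) -> int:
    # The next pipeline stage; the last stage wraps around to stage 0.
    return (rank + 1) % world_size

assert get_prev_rank(0, 4) == 3
assert get_next_rank(3, 4) == 0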
@@ -183,9 +180,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
        next_rank (int, optional): The rank of the recipient of the tensor.
    """
    if not gpc.is_pipeline_last_stage():
        _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
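send_forward is a no-op on the last stage because there is no downstream consumer of its activations. A stripped-down sketch of the point-to-point pattern these senders wrap, written against raw torch.distributed rather than ColossalAI's helpers and assuming an already-initialized process group:

import torch
import torch.distributed as dist

def send_forward_simple(output_tensor: torch.Tensor, next_rank: int) -> None:
    # Blocking send of this stage's activation to the next pipeline stage.
    dist.send(output_tensor, dst=next_rank)

def recv_forward_simple(shape, dtype, prev_rank: int, device) -> torch.Tensor:
    # The receiver must pre-allocate a buffer with the agreed shape and dtype.
    buffer = torch.empty(shape, dtype=dtype, device=device)
    dist.recv(buffer, src=prev_rank)
    return buffer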
@@ -338,8 +333,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
    Returns:
        Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
    """
    input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
                                                    tensor_send_prev=input_tensor_grad,
                                                    recv_prev=recv_prev,
                                                    recv_next=recv_next,
...
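This combined call lets a stage in the steady state of a pipeline schedule exchange all four tensors in one step. A condensed sketch of the batched pattern a _communicate-style routine can follow, using torch.distributed's P2POp and batch_isend_irecv, assuming an initialized process group and pre-agreed receive shapes (the exchange function below is illustrative, not the library's API):

import torch
import torch.distributed as dist

def exchange(output_tensor, input_tensor_grad, prev_rank, next_rank, recv_shape):
    # Pre-allocate receive buffers; shapes must match what the peers send.
    input_tensor = torch.empty(recv_shape, dtype=output_tensor.dtype, device=output_tensor.device)
    output_tensor_grad = torch.empty_like(output_tensor)
    ops = [
        dist.P2POp(dist.isend, output_tensor, next_rank),       # forward activation out
        dist.P2POp(dist.isend, input_tensor_grad, prev_rank),   # backward gradient out
        dist.P2POp(dist.irecv, input_tensor, prev_rank),        # forward activation in
        dist.P2POp(dist.irecv, output_tensor_grad, next_rank),  # backward gradient in
    ]
    # Posting all four ops at once lets sends and receives overlap and avoids
    # the deadlocks a fixed send-then-receive ordering can cause.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return input_tensor, output_tensor_grad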