OpenDAS / Megatron-LM

Commit cbad126a, authored Mar 03, 2021 by Deepak Narayanan

    Bring back call to ring_exchange() in _communicate()

parent 78cf869f
1 changed file with 36 additions and 26 deletions (+36, -26)

megatron/p2p_communication.py
@@ -21,7 +21,8 @@ from megatron import get_args
 from megatron import mpu


-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
+                 use_ring_exchange=False):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
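
For illustration, a sketch of what opting into the new flag could look like at a call site. This is hypothetical (the buffers output_tensor and input_gradient are made up, and the wrapper functions in megatron/p2p_communication.py that call _communicate() are untouched by this commit, so they keep the default use_ring_exchange=False):

    # Hypothetical call site: exchange tensors with both pipeline neighbours
    # via the ring_exchange path. Per the docstring, the helper returns
    # (tensor_recv_prev, tensor_recv_next).
    tensor_recv_prev, tensor_recv_next = _communicate(
        tensor_send_next=output_tensor,   # assumed: activation for next stage
        tensor_send_prev=input_gradient,  # assumed: gradient for previous stage
        recv_prev=True,
        recv_next=True,
        use_ring_exchange=True)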
@@ -34,6 +35,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             previous rank.
         recv_next: boolean for whether tensor should be received from
             next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+            API should be used.

     Returns:
         (tensor_recv_prev, tensor_recv_next)
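
One caveat worth noting: ring_exchange() is not part of stock PyTorch; it is only available in custom PyTorch builds that patch it in, so a caller opting into this path needs such a build. A hypothetical guard (not part of this diff) could be:

    # Hypothetical guard: only opt in when the installed PyTorch build
    # actually provides the custom ring_exchange() API.
    use_ring_exchange = hasattr(torch.distributed, 'ring_exchange')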
@@ -73,31 +76,38 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
-    ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(send_prev_op)
-    if tensor_recv_prev is not None:
-        recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(recv_prev_op)
-    if tensor_send_next is not None:
-        send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(send_next_op)
-    if tensor_recv_next is not None:
-        recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(recv_next_op)
-    if len(ops) > 0:
-        reqs = torch.distributed.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
+    if use_ring_exchange:
+        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
+                                        tensor_recv_prev=tensor_recv_prev,
+                                        tensor_send_next=tensor_send_next,
+                                        tensor_recv_next=tensor_recv_next,
+                                        group=mpu.get_pipeline_model_parallel_group())
+    else:
+        ops = []
+        if tensor_send_prev is not None:
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(send_prev_op)
+        if tensor_recv_prev is not None:
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(recv_prev_op)
+        if tensor_send_next is not None:
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_next,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(send_next_op)
+        if tensor_recv_next is not None:
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_next,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(recv_next_op)
+        if len(ops) > 0:
+            reqs = torch.distributed.batch_isend_irecv(ops)
+            for req in reqs:
+                req.wait()
     # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()
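
As context for the fallback branch that this commit wraps in the else clause: batch_isend_irecv() takes a list of P2POp descriptors, launches them together, and returns one request handle per op. A minimal self-contained sketch of the same pattern outside Megatron, runnable on CPU with the gloo backend (file name and tensor sizes are made up):

    # ring_demo.py -- minimal sketch of the P2POp/batch_isend_irecv pattern
    # used in _communicate()'s fallback branch. Run with:
    #   torchrun --nproc_per_node=2 ring_demo.py
    import torch
    import torch.distributed as dist


    def main():
        dist.init_process_group(backend="gloo")  # gloo: runs on CPU, no GPUs needed
        rank = dist.get_rank()
        world = dist.get_world_size()

        # Each rank sends a tensor filled with its own rank id to the next rank
        # in the ring and receives the previous rank's tensor, mirroring the
        # send-next / recv-prev half of _communicate().
        tensor_send_next = torch.full((4,), float(rank))
        tensor_recv_prev = torch.empty(4)

        ops = [
            dist.P2POp(dist.isend, tensor_send_next, (rank + 1) % world),
            dist.P2POp(dist.irecv, tensor_recv_prev, (rank - 1) % world),
        ]
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

        print(f"rank {rank}: received {tensor_recv_prev.tolist()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

On the NCCL backend that Megatron uses, req.wait() synchronizes CUDA streams rather than blocking the host, which is presumably why the code above follows the waits with torch.cuda.synchronize() (per its own comment, to protect against a race condition when using batch_isend_irecv()).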