OpenDAS / Megatron-LM

Commit 78cf869f, authored Feb 25, 2021 by Deepak Narayanan
Parent: 3cbf7547

Get PyTorch batched communication API working for interleaved schedule
Showing 2 changed files with 40 additions and 58 deletions:

  megatron/p2p_communication.py  +38 -56
  megatron/schedules.py          +2 -2
megatron/p2p_communication.py

@@ -21,8 +21,7 @@ from megatron import get_args
 from megatron import mpu


-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -35,8 +34,6 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             previous rank.
         recv_next: boolean for whether tensor should be received from
             next rank.
-        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
-            API should be used.

     Returns:
         (tensor_recv_prev, tensor_recv_next)
@@ -76,34 +73,28 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
-    if use_ring_exchange:
-        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                        tensor_recv_prev=tensor_recv_prev,
-                                        tensor_send_next=tensor_send_next,
-                                        tensor_recv_next=tensor_recv_next,
-                                        group=mpu.get_pipeline_model_parallel_group())
-    else:
-        ops = []
-        if tensor_send_prev is not None:
-            send_prev_op = torch.distributed.P2POp(
-                torch.distributed.isend, tensor_send_prev,
-                mpu.get_pipeline_model_parallel_prev_rank())
-            ops.append(send_prev_op)
-        if tensor_recv_prev is not None:
-            recv_prev_op = torch.distributed.P2POp(
-                torch.distributed.irecv, tensor_recv_prev,
-                mpu.get_pipeline_model_parallel_prev_rank())
-            ops.append(recv_prev_op)
-        if tensor_send_next is not None:
-            send_next_op = torch.distributed.P2POp(
-                torch.distributed.isend, tensor_send_next,
-                mpu.get_pipeline_model_parallel_next_rank())
-            ops.append(send_next_op)
-        if tensor_recv_next is not None:
-            recv_next_op = torch.distributed.P2POp(
-                torch.distributed.irecv, tensor_recv_next,
-                mpu.get_pipeline_model_parallel_next_rank())
-            ops.append(recv_next_op)
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor_send_prev,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(
+            torch.distributed.irecv, tensor_recv_prev,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor_send_next,
+            mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(
+            torch.distributed.irecv, tensor_recv_next,
+            mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(recv_next_op)
+    if len(ops) > 0:
+        reqs = torch.distributed.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
@@ -123,7 +114,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(timers=None, use_ring_exchange=False):
+def recv_forward(timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -134,14 +125,13 @@ def recv_forward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor


-def recv_backward(timers=None, use_ring_exchange=False):
+def recv_backward(timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -152,14 +142,13 @@ def recv_backward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad


-def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward(output_tensor, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
@@ -168,13 +157,12 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-send').stop()


-def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward(input_tensor_grad, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -183,13 +171,12 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward_recv_backward(output_tensor, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -200,14 +187,13 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=Fal
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward_recv_forward(input_tensor_grad, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -218,8 +204,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor
@@ -233,8 +218,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
-       recv_next=False,
-       use_ring_exchange=True)
+       recv_next=False)
     if timers is not None:
         timers('forward-send-forward-recv').stop()
     return input_tensor
@@ -248,8 +232,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
-       recv_next=recv_next,
-       use_ring_exchange=True)
+       recv_next=recv_next)
     if timers is not None:
         timers('backward-send-backward-recv').stop()
     return output_tensor_grad
@@ -265,8 +248,7 @@ def send_forward_backward_recv_forward_backward(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
-       recv_next=recv_next,
-       use_ring_exchange=True)
+       recv_next=recv_next)
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
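
For readers unfamiliar with the batched point-to-point API that _communicate now uses unconditionally, here is a minimal, self-contained sketch of the same pattern outside Megatron. It is not part of the commit: the gloo backend (assumed to support batch_isend_irecv in recent PyTorch), the fixed tensor shape, and the ring_neighbours helper (a stand-in for mpu.get_pipeline_model_parallel_prev_rank / _next_rank) are illustrative assumptions.

# Standalone sketch of the torch.distributed.batch_isend_irecv pattern used above.
import torch
import torch.distributed as dist


def ring_neighbours(rank, world_size):
    # Hypothetical stand-in for mpu.get_pipeline_model_parallel_{prev,next}_rank().
    return (rank - 1) % world_size, (rank + 1) % world_size


def exchange(tensor_send_next, tensor_send_prev, recv_prev, recv_next, shape=(4,)):
    prev_rank, next_rank = ring_neighbours(dist.get_rank(), dist.get_world_size())
    tensor_recv_prev = torch.empty(shape) if recv_prev else None
    tensor_recv_next = torch.empty(shape) if recv_next else None

    # Build one batch of point-to-point operations, mirroring the new
    # _communicate body: sends/recvs in both directions go into a single list.
    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)   # post the whole batch at once
        for req in reqs:
            req.wait()                       # block until every request finishes
    return tensor_recv_prev, tensor_recv_next


if __name__ == '__main__':
    # Launch with e.g. `torchrun --nproc_per_node=2 sketch.py` (two "pipeline stages").
    dist.init_process_group(backend='gloo')
    payload = torch.full((4,), float(dist.get_rank()))
    recv_prev, recv_next = exchange(payload, payload, recv_prev=True, recv_next=True)
    print(dist.get_rank(), recv_prev, recv_next)
    dist.destroy_process_group()

Because every send and receive is posted before any wait, a stage can exchange tensors with both of its neighbours in one shot, which is the case the interleaved schedule needs.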
megatron/schedules.py

@@ -210,7 +210,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers, use_ring_exchange=True))
+        p2p_communication.recv_forward(timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
@@ -322,7 +322,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
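
For context on the two call sites changed above: in the interleaved schedule each pipeline stage owns several model chunks, and the schedule keeps one queue of received activations (and one of received output gradients) per chunk. The toy sketch below mirrors only that bookkeeping; num_model_chunks, the tensor shape, and the fake_recv_* stand-ins for p2p_communication.recv_forward / recv_backward are illustrative assumptions, not Megatron code.

# Toy sketch of the per-chunk queues that the two changed lines feed.
import torch

num_model_chunks = 2                                         # virtual chunks per stage (assumed)
input_tensors = [[] for _ in range(num_model_chunks)]        # activations received per chunk
output_tensor_grads = [[] for _ in range(num_model_chunks)]  # output grads received per chunk


def fake_recv_forward():
    # Stand-in for p2p_communication.recv_forward(timers): activation from the previous stage.
    return torch.zeros(4)


def fake_recv_backward():
    # Stand-in for p2p_communication.recv_backward(timers): gradient from the next stage.
    return torch.zeros(4)


# Warmup: the first model chunk is seeded with the first incoming activation
# (mirrors the first hunk in schedules.py).
input_tensors[0].append(fake_recv_forward())

# Start of the backward phase: the last model chunk is seeded with the first
# incoming output gradient (mirrors the second hunk).
output_tensor_grads[num_model_chunks - 1].append(fake_recv_backward())

print(len(input_tensors[0]), len(output_tensor_grads[-1]))  # prints: 1 1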