OpenDAS / Megatron-LM · Commits

Commit a676bc2d
Authored Jul 27, 2021 by Deepak Narayanan
Merge branch 'main_p2p' into 'main'
See merge request ADLR/megatron-lm!293
Parents: 3db6517a, 1dccefd8
Showing 1 changed file with 44 additions and 12 deletions.
megatron/p2p_communication.py (view file @ a676bc2d)
@@ -22,7 +22,9 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+                 use_ring_exchange=False, tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    next rank.
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
+        tensor_shape: optional, use when the input sequence contains fewer
+                      tokens than the default sequence length
+        override_scatter_gather_tensors_in_pipeline: optional, used when
+                                                     tensor_shape is provided
+                                                     to override the
+                                                     scatter-gather of tensors
+        dtype_: optional, used when tensor_shape is provided; gives the
+                dtype of the tensor to be received
 
     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if args.scatter_gather_tensors_in_pipeline:
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
             mpu.get_tensor_model_parallel_world_size()
     else:
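As a quick sanity check on the chunk arithmetic: the scatter-gather path flattens the tensor and divides its element count by the tensor-model-parallel world size. A standalone sketch with made-up sizes:

    from functools import reduce
    import operator

    # Assumed values for illustration only.
    tensor_shape = (1024, 4, 2048)   # (seq_length, micro_batch_size, hidden_size)
    world_size = 8                   # mpu.get_tensor_model_parallel_world_size()

    numel = reduce(operator.mul, tensor_shape, 1)   # 8,388,608 elements
    tensor_chunk_shape = numel // world_size        # 1,048,576 elements per rank

Note that tensor_chunk_shape here is a scalar element count, not a tuple; torch.empty accepts either.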
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
+
+    requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
+
     if recv_prev:
         tensor_recv_prev = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
         tensor_recv_next = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
 
     # Split tensor into smaller chunks if using scatter-gather optimization.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
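The requires_grad switch exists because a buffer received with an explicit dtype_ (for example an integer tensor) is treated as plain data rather than a differentiable activation. A minimal sketch of the allocation logic in isolation, with assumed values:

    import torch

    dtype = torch.float16        # stand-in for args.params_dtype
    dtype_ = torch.int64         # hypothetical caller-supplied dtype
    requires_grad = True
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False    # integer buffers cannot require grad

    buf = torch.empty(1024, requires_grad=requires_grad, dtype=dtype)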
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     torch.cuda.synchronize()
 
     # If using scatter-gather optimization, gather smaller chunks.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
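For readers new to the scatter-gather optimization: on send, each rank transmits only its 1/world_size slice of the flattened tensor; on receive, the slices are gathered and reshaped back. A rough single-process approximation of the round trip, using plain torch stand-ins for the mpu helpers (which actually communicate across the tensor-model-parallel group):

    import torch

    world_size = 4                    # assumed parallel size
    full = torch.randn(8, 2, 4)       # stand-in activation

    # Approximates mpu.split_tensor_into_1d_equal_chunks:
    chunks = full.flatten().chunk(world_size)

    # Approximates mpu.gather_split_1d_tensor followed by the .view(...) above:
    rebuilt = torch.cat(chunks).view(full.shape).requires_grad_()
    assert torch.equal(rebuilt, full)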
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(timers=None):
+def recv_forward(tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None,
+                 timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
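Putting it together, a pipeline schedule can now receive an input whose shape differs from the default. A hedged usage sketch (all values hypothetical):

    import torch

    # e.g. a shorter-than-default sequence of 128 tokens.
    input_tensor = recv_forward(
        tensor_shape=(128, 8, 1024),
        override_scatter_gather_tensors_in_pipeline=True,
        dtype_=torch.float16)

send_forward mirrors this on the sending side, as the next hunks show.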
@@ -158,8 +184,11 @@ def recv_backward(timers=None):
     return output_tensor_grad
 
 
-def send_forward(output_tensor, timers=None):
+def send_forward(output_tensor, timers=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-send').stop()