OpenDAS / Megatron-LM · Commits

Commit 1dccefd8
Authored Jul 27, 2021 by Mostofa Patwary
Committed by Deepak Narayanan, Jul 27, 2021

Make it possible to pass in tensor shapes to communication methods in p2p_communication.py

Parent: 3db6517a
Showing 1 changed file with 44 additions and 12 deletions.

megatron/p2p_communication.py (+44, -12)
@@ -22,7 +22,9 @@ from megatron import mpu

 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+                 use_ring_exchange=False, tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    next rank.
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
+        tensor_shape: optional, use when the input sequence contains fewer
+                      tokens than the default sequence length
+        override_scatter_gather_tensors_in_pipeline: optional, used when
+                                                     tensor_shape is provided
+                                                     to override scatter-gather
+                                                     of the tensors
+        dtype_: optional, the dtype of the tensor when tensor_shape is provided
     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if args.scatter_gather_tensors_in_pipeline:
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
             mpu.get_tensor_model_parallel_world_size()
     else:
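
For intuition, the scatter-gather branch above flattens the activation and sends only a 1/world_size slice of it per rank. A minimal sketch of that chunk-size arithmetic, where the concrete shape and world size are illustrative assumptions rather than values from the commit:

# Hedged sketch of the tensor_chunk_shape computation; seq_length=1024,
# micro_batch_size=4, hidden_size=2048, and world size 8 are assumed values.
from functools import reduce
import operator

tensor_shape = (1024, 4, 2048)                  # (seq_length, micro_batch, hidden)
world_size = 8                                  # assumed tensor-model-parallel size
numel = reduce(operator.mul, tensor_shape, 1)   # 8388608 elements in total
tensor_chunk_shape = numel // world_size        # 1048576 elements per rank
print(tensor_chunk_shape)                       # -> 1048576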
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
+
+    requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
+
     if recv_prev:
         tensor_recv_prev = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
         tensor_recv_next = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)

     # Split tensor into smaller chunks if using scatter-gather optimization.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
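
One likely reason for the requires_grad toggle: dtype_ lets callers receive non-floating-point data, and PyTorch only allows floating-point (or complex) tensors to require grad. A minimal sketch under that assumption; the shape and dtype below are illustrative, and the buffer is allocated on CPU so the snippet runs without a GPU:

# Hedged sketch of the receive-buffer allocation above. Requesting
# requires_grad=True on an integer tensor raises a RuntimeError, which is
# presumably why the commit disables grad whenever dtype_ is supplied.
import torch

dtype_ = torch.int64             # assumed: e.g. receiving token ids
tensor_chunk_shape = (128, 4)    # assumed shape

requires_grad = True
if dtype_ is not None:           # mirrors the commit's logic
    requires_grad = False

buf = torch.empty(tensor_chunk_shape,
                  requires_grad=requires_grad,
                  dtype=dtype_)
print(buf.dtype, buf.requires_grad)  # -> torch.int64 False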
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     torch.cuda.synchronize()

     # If using scatter-gather optimization, gather smaller chunks.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(timers=None):
+def recv_forward(tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None,
+                 timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
     if timers is not None:
         timers('forward-recv').stop()
     return input_tensor
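
A hypothetical caller sketch, not part of the commit, showing how the extended recv_forward signature might be used on a stage that expects a shorter-than-default input sequence; the sequence length of 512 and torch.float16 are assumptions, and args/timers are taken to be in scope as elsewhere in the module:

# Hypothetical usage of the new recv_forward keyword arguments.
input_tensor = recv_forward(
    tensor_shape=(512, args.micro_batch_size, args.hidden_size),  # shorter sequence
    override_scatter_gather_tensors_in_pipeline=True,   # skip chunked transfer
    dtype_=torch.float16,                               # assumed receive dtype
    timers=timers)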
@@ -158,8 +184,11 @@ def recv_backward(timers=None):
     return output_tensor_grad


-def send_forward(output_tensor, timers=None):
+def send_forward(output_tensor, timers=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
     if timers is not None:
         timers('forward-send').stop()
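
The matching hypothetical send on the upstream stage (again a sketch, not from the commit). Note that send_forward gains no tensor_shape parameter, since the shape travels with output_tensor itself; the override and dtype_ values below must simply agree with what the receiving stage passed to recv_forward:

# Hypothetical matching send from the previous pipeline stage.
send_forward(output_tensor,
             timers=timers,
             override_scatter_gather_tensors_in_pipeline=True,
             dtype_=torch.float16)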