wuxk1 / Megatron-LM

Commit 6658158b authored Mar 02, 2022 by Vijay Korthikanti
get sequence parallelism to work with pipeline parallelism
parent c0f10643
Showing 6 changed files with 41 additions and 15 deletions:
megatron/model/transformer.py  +14 -6
megatron/mpu/__init__.py  +1 -1
megatron/mpu/mappings.py  +1 -1
megatron/p2p_communication.py  +6 -3
megatron/schedules.py  +15 -0
megatron/training.py  +4 -4
megatron/model/transformer.py

```diff
@@ -619,6 +619,8 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()
+        self.layer_type = layer_type
+        self.model_type = args.model_type
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
         self.pre_process = pre_process
@@ -629,7 +631,8 @@ class ParallelTransformer(MegatronModule):
         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method
         self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
-        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
+        self.distribute_checkpointed_activations = \
+            args.distribute_checkpointed_activations and not args.model_parallel_memory_opt
         self.model_parallel_memory_opt = args.model_parallel_memory_opt
@@ -807,9 +810,9 @@ class ParallelTransformer(MegatronModule):
                 )
         # Transpose encoder output.
-        if encoder_output is not None:
+        if encoder_output is not None and \
+                not self.model_parallel_memory_opt:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
             if self.model_parallel_memory_opt:
                 encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
@@ -835,10 +838,15 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)
-            if self.model_parallel_memory_opt:
-                hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
-            output = hidden_states.transpose(0, 1).contiguous()
+            if self.layer_type == LayerType.encoder and \
+                    self.model_type == ModelType.encoder_and_decoder and \
+                    self.model_parallel_memory_opt:
+                output = hidden_states
+            else:
+                if self.model_parallel_memory_opt:
+                    hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
+                output = hidden_states.transpose(0, 1).contiguous()
         else:
             output = hidden_states
```
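Note: the last hunk distinguishes two output layouts. With sequence parallelism, activations flow through the layers as [s/tp, b, h]; an encoder stage of an encoder-and-decoder model now forwards that shard as-is to the next pipeline stage, while every other stage gathers the sequence shards and reverts to [b, s, h]. The following is a minimal single-process sketch of those two layouts; `tp`, `s`, `b`, `h` are illustrative names and the all-gather is simulated with `torch.cat`, so this is not the repository's code.

```python
import torch

tp, s, b, h = 4, 8, 2, 16                   # tensor-parallel ranks, seq, batch, hidden
local_shard = torch.randn(s // tp, b, h)    # [s/tp, b, h] held by one rank

# Encoder stage of an encoder-and-decoder model with sequence parallelism:
# the shard is handed to the next pipeline stage unchanged.
encoder_output = local_shard                # stays [s/tp, b, h]

# Any other stage: gather the sequence shards (simulated here), then revert
# the data format to [b, s, h] for the downstream head.
gathered = torch.cat([local_shard] * tp, dim=0)        # stand-in for the all-gather
final_output = gathered.transpose(0, 1).contiguous()   # [b, s, h]
print(encoder_output.shape, final_output.shape)
```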
megatron/mpu/__init__.py

```diff
@@ -61,7 +61,7 @@ from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import scatter_to_sequence_parallel_region
-from .mappings import gather_from_seqeuence_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import checkpoint
```
megatron/mpu/mappings.py

```diff
@@ -278,7 +278,7 @@ def scatter_to_sequence_parallel_region(input_):
     return _ScatterToSequenceParallelRegion.apply(input_)


-def gather_from_seqeuence_parallel_region(input_):
+def gather_from_sequence_parallel_region(input_):
     return _GatherFromSequenceParallelRegion.apply(input_)
```
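Note: both of these helpers wrap autograd functions; conceptually the scatter splits the activation along the sequence (first) dimension across the tensor-model-parallel group, and the gather reassembles it, with the backward of each being the other. Below is a toy, single-process sketch of that pairing, not the repository's implementation: the parallel group is simulated by a plain Python list of per-rank shards and no collective communication is performed.

```python
import torch

def scatter_sequence(full, tp):
    """Split [s, b, h] into tp shards of [s/tp, b, h], one per rank."""
    return list(torch.chunk(full, tp, dim=0))

def gather_sequence(shards):
    """Reassemble the full [s, b, h] tensor from the per-rank shards.
    In the real op the backward is the split above, so each rank keeps
    only the gradient slice for its own sequence chunk."""
    return torch.cat(shards, dim=0)

full = torch.randn(8, 2, 4)            # [s, b, h]
shards = scatter_sequence(full, tp=4)  # four [2, 2, 4] shards
assert torch.equal(gather_sequence(shards), full)
```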
megatron/p2p_communication.py

```diff
@@ -61,7 +61,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    override_scatter_gather_tensors_in_pipeline = False
-    if args.scatter_gather_tensors_in_pipeline:
+    if args.scatter_gather_tensors_in_pipeline and \
+            not args.model_parallel_memory_opt:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
            tensor_chunk_shape = tensor_chunk_shape // \
@@ -93,7 +94,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.model_parallel_memory_opt:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -138,7 +140,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.model_parallel_memory_opt:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
```
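Note: the scatter-gather pipeline optimization splits the full [s, b, h] activation into 1-D chunks across tensor-parallel ranks before the send. With sequence parallelism the tensor on the wire is already partitioned along the sequence dimension, so that extra split is redundant and is now skipped. A back-of-the-envelope check with illustrative numbers (not taken from the repo):

```python
# Per-rank payload is identical whether the full tensor is chunked 1-D
# or the activation is already sharded along the sequence dimension.
seq_length, micro_batch_size, hidden_size, tp = 2048, 4, 4096, 8

full = seq_length * micro_batch_size * hidden_size                   # [s, b, h] elements
chunked = full // tp                                                  # 1-D scatter-gather chunk
seq_parallel = (seq_length // tp) * micro_batch_size * hidden_size   # [s/tp, b, h] shard

assert chunked == seq_parallel
print(f"full={full:,}  per-rank={chunked:,}")
```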
megatron/schedules.py

```diff
@@ -514,6 +514,21 @@ def get_tensor_shapes(rank, model_type):
     # Otherwise, send one tensor (pre-transpose).
     args = get_args()
     tensor_shapes = []
+    if args.model_parallel_memory_opt:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+        if model_type == ModelType.encoder_and_decoder:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+            if mpu.is_pipeline_stage_before_split(rank):
+                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+            else:
+                tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
+                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+        else:
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
+        return tensor_shapes
     if model_type == ModelType.encoder_and_decoder:
         if mpu.is_pipeline_stage_before_split(rank):
             # If next rank is after split, then need transpose for encoder_hidden_state.
```
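Note: the new branch divides the (decoder) sequence length by the tensor-model-parallel world size so that the pipeline send/recv buffers match the sequence-sharded activations. A worked example with hypothetical sizes:

```python
# Shapes produced by the sequence-parallel branch of get_tensor_shapes
# (illustrative numbers only).
seq_length, decoder_seq_length = 2048, 512
micro_batch_size, hidden_size, tp = 4, 4096, 8

before_split_shape = (seq_length // tp, micro_batch_size, hidden_size)
after_split_shapes = [
    (decoder_seq_length // tp, micro_batch_size, hidden_size),  # decoder stream
    (seq_length // tp, micro_batch_size, hidden_size),          # encoder hidden state
]
print(before_split_shape)   # (256, 4, 4096)
print(after_split_shapes)   # [(64, 4, 4096), (256, 4, 4096)]
```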
megatron/training.py

```diff
@@ -421,16 +421,16 @@ def train_step(forward_step_func, data_iterator,
     # All-reduce layernorm parameters across model parallel nodes
     # when sequence parallelism is used
-    if args.get_tensor_model_parallel_world_size > 1 and \
+    if mpu.get_tensor_model_parallel_world_size() > 1 and \
             args.model_parallel_memory_opt:
         grads = []
         for model_module in model:
             unwrapped_model = unwrap_model(
                 model_module, (torchDDP, LocalDDP, Float16Module))
             for param in unwrapped_model.parameters():
-                if param.get_attr('sequence_parallel', False):
-                    assert param.requires_grad and param.grad is not None
-                    grads.append(param.grad.data)
+                if getattr(param, 'sequence_parallel', False):
+                    grad = param.main_grad if args.DDP_impl == 'local' else param.grad
+                    grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
         coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
```
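Note: under sequence parallelism each tensor-parallel rank computes layernorm parameter gradients from only its own sequence shard, so the gradients of parameters tagged `sequence_parallel` have to be averaged across the group; the fix also reads `main_grad` when the local DDP wrapper is used, since `param.grad` may be unset there. Below is a sketch of the coalesced averaging pattern, not the repository's code: the world size and gradient shapes are illustrative, and the collective is left as a comment (presumably over the tensor-model-parallel group) so the snippet runs on a single process.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

tp_world_size = 8                                # illustrative
grads = [torch.randn(1024), torch.randn(1024)]   # e.g. layernorm weight/bias grads

coalesced = _flatten_dense_tensors(grads)
coalesced /= tp_world_size
# torch.distributed.all_reduce(coalesced, group=...)  # tensor-model-parallel group
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)
```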