wuxk1 / Megatron-LM

Commit eec218d8
Authored Mar 30, 2022 by Vijay Korthikanti
Parent: 53718d4c

    sequence parallelism for embedding dropout and last linear layer + memory optimizations

Showing 7 changed files with 93 additions and 81 deletions:

    megatron/model/language_model.py   +20  -3
    megatron/model/transformer.py      +31  -47
    megatron/mpu/layers.py             +20  -9
    megatron/optimizer/optimizer.py    +9   -2
    megatron/schedules.py              +10  -18
    megatron/training.py               +2   -1
    pretrain_gpt.py                    +1   -1
megatron/model/language_model.py

@@ -37,7 +37,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
         input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
-            model_parallel
+            model_parallel and not args.model_parallel_memory_opt
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

@@ -46,7 +46,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, None)
+        async_grad_allreduce, args.model_parallel_memory_opt)

     # Gather if needed.
     if parallel_output:
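
The two hunks above route the last linear layer through the fused linear function with the sequence-parallel flag instead of None, and disable the asynchronous grad all-reduce whenever the memory optimization is on. A rough sketch of the two input-handling paths (not the Megatron implementation), assuming an already-initialized tensor-parallel process group tp_group:

import torch
import torch.distributed as dist

def logits_input_tensor_parallel(input_, tp_group):
    # async_grad_allreduce path: the input is already replicated across the
    # tensor-parallel ranks, so the forward pass is an identity and the
    # backward pass all-reduces the input gradient (that all-reduce is what
    # gets overlapped with the weight-gradient GEMM).
    return input_

def logits_input_sequence_parallel(input_shard, tp_group):
    # model_parallel_memory_opt path: each rank holds a [s/p, b, h] shard, so
    # the forward pass must all-gather along the sequence dimension before the
    # logits GEMM, and the backward pass reduce-scatters instead of
    # all-reducing -- which is why async_grad_allreduce is forced off above.
    world_size = dist.get_world_size(tp_group)
    shards = [torch.empty_like(input_shard) for _ in range(world_size)]
    dist.all_gather(shards, input_shard, group=tp_group)
    return torch.cat(shards, dim=0)  # [s, b, h]
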
@@ -170,6 +170,8 @@ class Embedding(MegatronModule):
         else:
             self.tokentype_embeddings = None

+        self.fp32_residual_connection = args.fp32_residual_connection
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

@@ -211,8 +213,23 @@ class Embedding(MegatronModule):
         else:
             assert self.tokentype_embeddings is None

+        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+        # If the input flag for fp32 residual connection is set, convert for float.
+        if self.fp32_residual_connection:
+            embeddings = embeddings.transpose(0, 1).contiguous().float()
+        # Otherwise, leave it as is.
+        else:
+            embeddings = embeddings.transpose(0, 1).contiguous()
+
+        if self.model_parallel_memory_opt:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
+
         # Dropout.
-        embeddings = self.embedding_dropout(embeddings)
+        if self.model_parallel_memory_opt:
+            with mpu.get_cuda_rng_tracker().fork():
+                embeddings = self.embedding_dropout(embeddings)
+        else:
+            embeddings = self.embedding_dropout(embeddings)

         return embeddings
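
In the forward path above the embedding output is scattered along the sequence dimension before dropout, and the dropout itself runs under the forked CUDA RNG tracker. A minimal sketch of what the scatter does in the forward pass (the real mpu.scatter_to_sequence_parallel_region is an autograd function whose backward gathers the gradients back), assuming torch.distributed is initialized and tp_group is the tensor-parallel group:

import torch
import torch.distributed as dist

def scatter_along_sequence_dim(embeddings, tp_group):
    # embeddings: [s, b, h] -> local shard [s // world_size, b, h]
    world_size = dist.get_world_size(tp_group)
    rank = dist.get_rank(tp_group)
    assert embeddings.size(0) % world_size == 0, \
        "sequence length must divide evenly across the tensor-parallel group"
    return embeddings.chunk(world_size, dim=0)[rank].contiguous()

Running the dropout inside mpu.get_cuda_rng_tracker().fork() gives every tensor-parallel rank its own random stream, so each sequence shard draws an independent dropout mask instead of all ranks replaying the same replicated mask.
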
megatron/model/transformer.py

@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn.functional as F

-from megatron import get_args
+from megatron import get_timers, get_args, print_rank_last, print_rank_0
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType

@@ -27,6 +27,8 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

+_MATMUL_INPUT = None
+
 """ We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads

@@ -42,7 +44,6 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """

-
 class DropPath(MegatronModule):
     """Drop paths (Stochastic Depth) per sample
     (when applied in main path of residual blocks).

@@ -189,7 +190,18 @@ class CoreAttention(MegatronModule):
             world_size)
         self.hidden_size_per_attention_head = mpu.divide(
             projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

+        global _MATMUL_INPUT
+        if _MATMUL_INPUT is None:
+            _MATMUL_INPUT = torch.empty(
+                args.micro_batch_size * self.num_attention_heads_per_partition,
+                args.seq_length,
+                args.seq_length,
+                dtype=torch.bfloat16,
+                device=torch.cuda.current_device())
+
         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:

@@ -230,16 +242,19 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

-        # preallocting result tensor: [b * np, sq, sk]
-        matmul_result = torch.empty(
-            output_size[0]*output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device())
+        #matmul_result = torch.empty(
+        #    output_size[0]*output_size[1],
+        #    output_size[2],
+        #    output_size[3],
+        #    dtype=query_layer.dtype,
+        #    device=torch.cuda.current_device())
+        global _MATMUL_INPUT
+        matmul_input = _MATMUL_INPUT

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            matmul_result,
+            matmul_input,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
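
The preallocated _MATMUL_INPUT buffer is never initialized, which is safe only because baddbmm is called with beta=0.0: the input tensor is then used purely as a shape/dtype template for the output and its contents are never read. A small standalone check of that property (illustrative sizes, not Megatron's):

import torch

b_np, sq, sk, hn = 4, 8, 8, 16
buf = torch.empty(b_np, sq, sk)          # uninitialized, reused across calls
q = torch.randn(b_np, sq, hn)
k = torch.randn(b_np, sk, hn)

scores = torch.baddbmm(buf, q, k.transpose(1, 2), beta=0.0, alpha=1.0)
assert torch.allclose(scores, torch.bmm(q, k.transpose(1, 2)), atol=1e-5)

Reusing one global buffer avoids a fresh [b*np, sq, sk] allocation in every layer and micro-batch, at the cost of hard-coding micro_batch_size, seq_length, and a bfloat16 dtype at construction time.
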
@@ -838,6 +853,7 @@ class ParallelTransformer(MegatronModule):
                     self.distribute_checkpointed_activations,
                     hidden_states, attention_mask,
                     encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
+
         elif self.activations_checkpoint_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.

@@ -869,25 +885,12 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
         # Checks.
         if inference_params:
             assert self.activations_checkpoint_method is None, \
                 'inference does not work with activation checkpointing'

-        if self.pre_process:
-            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-            # If the input flag for fp32 residual connection is set, convert for float.
-            if self.fp32_residual_connection:
-                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
-            # Otherwise, leave it as is.
-            else:
-                hidden_states = hidden_states.transpose(0, 1).contiguous()
-
-            if self.model_parallel_memory_opt:
-                hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)
-        else:
+        if not self.pre_process:
             # See set_input_tensor()
             hidden_states = self.input_tensor

@@ -908,17 +911,10 @@ class ParallelTransformer(MegatronModule):
         # is called here to be future-proof and corner-case-proof.
         hidden_states = mpu.make_viewless_tensor(
             hidden_states,
-            requires_grad = True,
-            keep_graph = True,
+            requires_grad=True,
+            keep_graph=True,
         )

-        # Transpose encoder output.
-        if encoder_output is not None and \
-                not self.model_parallel_memory_opt:
-            encoder_output = encoder_output.transpose(0, 1).contiguous()
-            if self.model_parallel_memory_opt:
-                encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
-
         if self.model_parallel_memory_opt:
             with mpu.get_cuda_rng_tracker().fork():
                 # Forward pass.

@@ -928,6 +924,7 @@ class ParallelTransformer(MegatronModule):
                         encoder_output,
                         enc_dec_attn_mask)
                 else:
+                    total = 0
                     for index in range(self.num_layers):
                         layer = self._get_layer(index)
                         hidden_states = layer(

@@ -936,6 +933,7 @@ class ParallelTransformer(MegatronModule):
                             encoder_output=encoder_output,
                             enc_dec_attn_mask=enc_dec_attn_mask,
                             inference_params=inference_params)
+
             else:
                 # Forward pass.
                 if self.activations_checkpoint_method is not None:

@@ -955,20 +953,6 @@ class ParallelTransformer(MegatronModule):
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = self.final_layernorm(hidden_states)
-            if self.layer_type == LayerType.encoder and \
-                    self.model_type == ModelType.encoder_and_decoder and \
-                    self.model_parallel_memory_opt:
-                output = hidden_states
-            else:
-                if self.model_parallel_memory_opt:
-                    hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
-                output = hidden_states.transpose(0, 1).contiguous()
-        else:
-            output = hidden_states
-
-        return output
+
+        return hidden_states
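
With the pre_process scatter moved into the embedding and the post_process gather/transpose removed, the model stays in [s, b, h] layout end to end and the hidden states leaving the transformer stay sequence-sharded. A quick way to keep the new shape bookkeeping straight (a hypothetical helper, not part of the commit; sizes are illustrative):

def expected_shapes(s, b, h, tp):
    return {
        "embedding_out":   (s, b, h),        # transpose now happens in Embedding.forward
        "sequence_shard":  (s // tp, b, h),  # after scatter_to_sequence_parallel_region
        "transformer_out": (s // tp, b, h),  # gather removed from post_process
        "labels":          (s, b),           # hence the transpose in pretrain_gpt.py
    }

print(expected_shapes(s=2048, b=4, h=2560, tp=8))
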
megatron/mpu/layers.py

@@ -41,11 +41,12 @@ from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
 from megatron import get_args

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

+_TOTAL_INPUT = None
 _SUB_GRAD_INPUT = None

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and

@@ -221,9 +222,11 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            total_input = torch.empty(dim_size, dtype=input.dtype,
-                                      device=torch.cuda.current_device(),
-                                      requires_grad=False)
+            #total_input = torch.empty(dim_size, dtype=input.dtype,
+            #                          device=torch.cuda.current_device(),
+            #                          requires_grad=False)
+            global _TOTAL_INPUT
+            total_input = _TOTAL_INPUT
             torch.distributed._all_gather_base(total_input, input,
                 group=get_tensor_model_parallel_group())
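
Instead of allocating the gathered-input tensor on every call, the forward now reuses the module-level _TOTAL_INPUT buffer. The reuse pattern looks roughly like this (hypothetical names; _all_gather_base is the same private torch.distributed API the commit itself uses):

import torch
import torch.distributed as dist

_GATHER_BUF = None

def gather_sequence_parallel_input(local_input, tp_group):
    global _GATHER_BUF
    world_size = dist.get_world_size(tp_group)
    full_size = list(local_input.size())
    full_size[0] *= world_size
    if _GATHER_BUF is None:  # allocate once, reuse for every layer and step
        _GATHER_BUF = torch.empty(full_size, dtype=local_input.dtype,
                                  device=local_input.device, requires_grad=False)
    torch.distributed._all_gather_base(_GATHER_BUF, local_input, group=tp_group)
    return _GATHER_BUF

This trades per-call allocator traffic (and potential fragmentation from a large per-layer temporary) for a single persistent buffer whose shape and dtype must match every input it is handed.
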
@@ -246,9 +249,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size

-            total_input = torch.empty(dim_size, dtype=input.dtype,
-                                      device=torch.cuda.current_device(),
-                                      requires_grad=False)
+            #total_input = torch.empty(dim_size, dtype=input.dtype,
+            #                          device=torch.cuda.current_device(),
+            #                          requires_grad=False)
+            global _TOTAL_INPUT
+            total_input = _TOTAL_INPUT
             handle = torch.distributed._all_gather_base(total_input, input,
                 group=get_tensor_model_parallel_group(), async_op=True)

             # Delay the start of intput gradient computation shortly (3us) to have

@@ -279,8 +285,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             assert not ctx.async_grad_allreduce
             dim_size = list(input.size())
             sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
-                                         device=torch.cuda.current_device(),
-                                         requires_grad=False)
+                device=torch.cuda.current_device(),
+                requires_grad=False)
             # reduce_scatter
             handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
                                                             group=get_tensor_model_parallel_group(),

@@ -390,6 +396,11 @@ class ColumnParallelLinear(torch.nn.Module):
         assert not self.async_tensor_model_parallel_allreduce or \
             not self.model_parallel_memory_opt
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

+        global _TOTAL_INPUT
+        if _TOTAL_INPUT is None:
+            _TOTAL_INPUT = torch.empty(
+                (args.seq_length, args.micro_batch_size, args.hidden_size),
+                dtype=torch.bfloat16, device=torch.cuda.current_device(), requires_grad=False)

     def forward(self, input_):
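
Because _TOTAL_INPUT is allocated once with a hard-coded (seq_length, micro_batch_size, hidden_size) shape and a bfloat16 dtype, the optimization only holds while every ColumnParallelLinear sees exactly that gathered-input shape and dtype. A guard along these lines (hypothetical, not in the commit) makes the assumption explicit:

def check_gather_buffer(buf, local_input, world_size):
    expected = [local_input.size(0) * world_size] + list(local_input.shape[1:])
    assert list(buf.shape) == expected and buf.dtype == local_input.dtype, \
        "preallocated _TOTAL_INPUT does not match this input; it assumes fixed " \
        "(seq_length, micro_batch_size, hidden_size) bfloat16 activations"
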
megatron/optimizer/optimizer.py

@@ -264,6 +264,13 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \
                                 = self.optimizer.state.pop(param)
+                        #state = self.optimizer.state[main_param]
+                        #if len(state) == 0:
+                        #    # Exponential moving average of gradient values
+                        #    state['exp_avg'] = torch.zeros_like(main_param.data)
+                        #    # Exponential moving average of squared gradient values
+                        #    state['exp_avg_sq'] = torch.zeros_like(main_param.data)
+
                     # fp32 params.
                     elif param.type() == 'torch.cuda.FloatTensor':

@@ -284,8 +291,8 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optimizer.load_state_dict(self.optimizer.state_dict())
+        #self.optimizer.load_state_dict(self.optimizer.state_dict())

     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
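
Both hunks read as memory-motivated: existing Adam state is re-keyed from the fp16 parameter to its new fp32 main parameter, while the eager creation of zero-filled exp_avg / exp_avg_sq tensors and the state_dict round trip are commented out, so state is only materialized when the optimizer actually needs it. The re-keying itself boils down to the following sketch:

def rekey_optimizer_state(optimizer, old_param, new_param):
    # Move any per-parameter state (e.g. Adam's exp_avg / exp_avg_sq) from the
    # fp16 parameter key to the fp32 main-parameter key without copying data.
    if old_param in optimizer.state:
        optimizer.state[new_param] = optimizer.state.pop(old_param)
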
megatron/schedules.py

@@ -517,30 +517,22 @@ def get_tensor_shapes(rank, model_type):
     if args.model_parallel_memory_opt:
         seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
-        if model_type == ModelType.encoder_and_decoder:
-            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
-            if mpu.is_pipeline_stage_before_split(rank):
-                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
-            else:
-                tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
-                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
-        else:
-            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
-        return tensor_shapes
+    else:
+        seq_length = args.seq_length

     if model_type == ModelType.encoder_and_decoder:
+        if args.model_parallel_memory_opt:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+        else:
+            decoder_seq_length = args.decoder_seq_length
         if mpu.is_pipeline_stage_before_split(rank):
-            # If next rank is after split, then need transpose for encoder_hidden_state.
-            if mpu.is_pipeline_stage_before_split(rank+1):
-                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
-            else:
-                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
         else:
-            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
-            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+            tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
     else:
-        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+        tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
     return tensor_shapes
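
For a decoder-only model the new logic reduces to dividing the communicated sequence length by the tensor-parallel world size whenever the memory optimization is on; a simplified restatement (sketch, the real function also handles ModelType.encoder_and_decoder):

def p2p_tensor_shape(seq_length, micro_batch_size, hidden_size,
                     tp_world_size, sequence_parallel):
    s = seq_length // tp_world_size if sequence_parallel else seq_length
    return (s, micro_batch_size, hidden_size)

# e.g. sequence 2048, micro-batch 4, hidden 2560, TP=8 with the memory opt on:
assert p2p_tensor_shape(2048, 4, 2560, 8, True) == (256, 4, 2560)

Note that even without the memory optimization the shapes are now (seq_length, micro_batch_size, hidden_size): the [b, s, h] encoder-output special case appears to go away because the transpose was removed from the transformer.
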
megatron/training.py

@@ -432,7 +432,8 @@ def train_step(forward_step_func, data_iterator,
             grad = param.main_grad if args.DDP_impl == 'local' else param.grad
             grads.append(grad.data)
         coalesced = _flatten_dense_tensors(grads)
-        coalesced /= mpu.get_tensor_model_parallel_world_size()
+        #TODO VIJAY
+        #coalesced /= mpu.get_tensor_model_parallel_world_size()
         torch.distributed.all_reduce(
             coalesced, group=mpu.get_tensor_model_parallel_group())
         for buf, synced in zip(
             grads, _unflatten_dense_tensors(
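
The division by the tensor-parallel world size is commented out (flagged TODO) while the all-reduce over the tensor-parallel group stays. One reading, consistent with sequence parallelism: each rank now computes layer-norm gradients from its own sequence shard only, so the per-rank gradients are partial sums and the correct combination is a plain sum rather than an average. The surrounding pattern, as a self-contained sketch:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_duplicated_grads(grads, tp_group):
    # Flatten into one buffer, sum across the tensor-parallel group (no
    # division), then copy the synced values back into the original grads.
    coalesced = _flatten_dense_tensors(grads)
    dist.all_reduce(coalesced, group=tp_group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
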
pretrain_gpt.py

@@ -59,7 +59,7 @@ def get_batch(data_iterator):
     # Unpack.
     tokens_ = data_b['text'].long()
-    labels = tokens_[:, 1:].contiguous()
+    labels = tokens_[:, 1:].transpose(0, 1).contiguous()
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and postition ids.
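
Since the model now keeps [s, b, h] layout all the way to the logits, the labels must be [s, b] to line up with [s, b, vocab] logits; a minimal shape check (illustrative sizes):

import torch

b, s, v = 4, 16, 32000
logits = torch.randn(s, b, v)
tokens_ = torch.randint(v, (b, s + 1))
labels = tokens_[:, 1:].transpose(0, 1).contiguous()  # [s, b]
assert labels.shape == logits.shape[:2]
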