evt_fugx1 / dcu_megatron · Commits

Commit 1dc8bc8a, authored Apr 25, 2025 by dongcl

fix the bug related to parameter sharing

parent 5b1e05ab
Showing 4 changed files with 37 additions and 136 deletions.
dcu_megatron/adaptor/megatron_adaptor.py                      +15  -4
dcu_megatron/core/distributed/finalize_model_grads.py          +1  -13
dcu_megatron/core/models/gpt/gpt_model.py                     +16  -86
dcu_megatron/core/transformer/mtp/multi_token_predictor.py     +5  -33
dcu_megatron/adaptor/megatron_adaptor.py

@@ -100,7 +100,11 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.models.gpt.gpt_model import (
             gpt_model_forward,
             gpt_model_init_wrapper,
-            shared_embedding_or_mtp_embedding_weight
+            shared_embedding_or_output_weight,
         )
+        from ..core.models.common.language_module.language_module import (
+            setup_embeddings_and_output_layer,
+            tie_embeddings_and_output_weights_state_dict
+        )
         from ..training.utils import get_batch_on_this_tp_rank
@@ -115,14 +119,21 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
                                     get_batch_on_this_tp_rank)

         # GPT Model
+        MegatronAdaptation.register(
+            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
+            setup_embeddings_and_output_layer)
+        MegatronAdaptation.register(
+            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
+            tie_embeddings_and_output_weights_state_dict)
+        MegatronAdaptation.register(
+            'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
+            shared_embedding_or_output_weight)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper,
                                     apply_wrapper=True)
-        from megatron.core.models.gpt.gpt_model import GPTModel
-        setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight',
-                shared_embedding_or_mtp_embedding_weight)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
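MegatronAdaptation.register is the adaptor's patch-by-dotted-path hook; its implementation is not part of this commit. As a rough, hypothetical sketch only, a class-attribute registration of this shape boils down to resolving the module, looking up the class, and swapping the attribute (register_patch and its behaviour below are assumptions, not the project's API):

    # Minimal sketch of a dotted-path patcher; assumes a "<module>.<Class>.<attr>" target
    # and stands in for whatever MegatronAdaptation.register does internally.
    import importlib

    def register_patch(dotted_path: str, replacement, apply_wrapper: bool = False):
        module_path, cls_name, attr_name = dotted_path.rsplit(".", 2)
        owner = getattr(importlib.import_module(module_path), cls_name)
        original = getattr(owner, attr_name)
        # With apply_wrapper=True the replacement is a decorator that receives the
        # original callable (as gpt_model_init_wrapper does); otherwise it replaces it outright.
        setattr(owner, attr_name, replacement(original) if apply_wrapper else replacement)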
dcu_megatron/core/distributed/finalize_model_grads.py

@@ -28,22 +28,10 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
         model_module = model[0]
         model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-        if model_module.share_embeddings_and_output_weights:
+        if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
             weight = model_module.shared_embedding_or_output_weight()
             grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
             orig_grad = getattr(weight, grad_attr)
             grad = _unshard_if_dtensor(orig_grad)
             torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
             setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
-
-        if (
-            hasattr(model_module, "share_mtp_embedding_and_output_weight")
-            and model_module.share_mtp_embedding_and_output_weight
-            and config.num_nextn_predict_layers > 0
-        ):
-            weight = model_module.shared_embedding_or_mtp_embedding_weight()
-            grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
-            orig_grad = getattr(weight, grad_attr)
-            grad = _unshard_if_dtensor(orig_grad)
-            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
-            setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
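The deleted second branch used to all-reduce the MTP embedding gradient separately; after the change a single condition covers both classic embedding/output tying and the MTP case, and the weight is always fetched through shared_embedding_or_output_weight(). A self-contained sketch of just that predicate, with illustrative names that are not part of the real module:

    import torch

    def needs_embedding_grad_allreduce(model_module: torch.nn.Module, config) -> bool:
        # Single predicate after this commit: reduce the shared embedding gradient
        # across the embedding group when the weights are tied OR when multi-token
        # prediction layers are configured (the MTP head reuses the input embedding).
        return bool(
            getattr(model_module, "share_embeddings_and_output_weights", False)
            or getattr(config, "num_nextn_predict_layers", 0)
        )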
dcu_megatron/core/models/gpt/gpt_model.py

@@ -30,6 +30,15 @@ def gpt_model_init_wrapper(fn):
     def wrapper(self, *args, **kwargs):
         fn(self, *args, **kwargs)

+        if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
+            self.embedding = LanguageModelEmbedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=kwargs.get("position_embedding_type"),
+                scatter_to_sequence_parallel=kwargs.get("scatter_embedding_sequence_parallel"),
+            )
+
         if (self.post_process and int(os.getenv("USE_FLUX_OVERLAP", "0"))
@@ -55,7 +64,6 @@ def gpt_model_init_wrapper(fn):
         if self.num_nextn_predict_layers:
             assert hasattr(self.config, "mtp_spec")
             self.mtp_spec: ModuleSpec = self.config.mtp_spec
-            self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
             self.recompute_mtp_norm = self.config.recompute_mtp_norm
             self.recompute_mtp_layer = self.config.recompute_mtp_layer
             self.mtp_loss_scale = self.config.mtp_loss_scale
@@ -74,7 +82,6 @@ def gpt_model_init_wrapper(fn):
                     position_embedding_type=self.position_embedding_type,
                     rotary_percent=self.rotary_percent,
                     seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
-                    share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
                     recompute_mtp_norm=self.recompute_mtp_norm,
                     recompute_mtp_layer=self.recompute_mtp_layer,
                     add_output_layer_bias=False
@@ -83,95 +90,22 @@ def gpt_model_init_wrapper(fn):
                 ]
             )

-        if self.pre_process or self.post_process:
-            setup_mtp_embeddings(self)
     return wrapper


-def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
-    """Gets the embedding weight when share embedding and mtp embedding weights set to True.
+def shared_embedding_or_output_weight(self) -> Tensor:
+    """Gets the embedding weight or output logit weights when share embedding and output weights set to True.

     Returns:
-        Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
-        mtp embedding layers weight
+        Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
     """
-    assert self.num_nextn_predict_layers > 0
-    if self.pre_process:
+    if self.pre_process or (self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0)):
         return self.embedding.word_embeddings.weight
     elif self.post_process:
-        return self.mtp_layers[0].embedding.word_embeddings.weight
+        return self.output_layer.weight
     return None


-def setup_mtp_embeddings(self):
-    """
-    Share embedding layer in mtp layer.
-    """
-    if self.pre_process:
-        self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
-
-    # Set `is_embedding_or_output_parameter` attribute.
-    for i in range(self.num_nextn_predict_layers):
-        if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
-            self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
-
-    if not self.share_mtp_embedding_and_output_weight:
-        return
-
-    if self.pre_process and self.post_process:
-        # Zero out wgrad if sharing embeddings between two layers on same
-        # pipeline stage to make sure grad accumulation into main_grad is
-        # correct and does not include garbage values (e.g., from torch.empty).
-        self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
-        return
-
-    if self.pre_process and not self.post_process:
-        assert parallel_state.is_pipeline_first_stage()
-        self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
-
-    if self.post_process and not self.pre_process:
-        assert not parallel_state.is_pipeline_first_stage()
-        for i in range(self.num_nextn_predict_layers):
-            # set word_embeddings weights to 0 here, then copy first
-            # stage's weights using all_reduce below.
-            self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
-            self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
-            self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
-
-    # Parameters are shared between the word embeddings layers, and the
-    # heads at the end of the model. In a pipelined setup with more than
-    # one stage, the initial embedding layer and the head are on different
-    # workers, so we do the following:
-    # 1. Create a second copy of word_embeddings on the last stage, with
-    #    initial parameters of 0.0.
-    # 2. Do an all-reduce between the first and last stage to ensure that
-    #    the two copies of word_embeddings start off with the same
-    #    parameter values.
-    # 3. In the training loop, before an all-reduce between the grads of
-    #    the two word_embeddings layers to ensure that every applied weight
-    #    update is the same on both stages.
-    # Ensure that first and last stages have the same initial parameter
-    # values.
-    if torch.distributed.is_initialized():
-        if parallel_state.is_rank_in_embedding_group():
-            weight = self.shared_embedding_or_mtp_embedding_weight()
-            weight.data = weight.data.cuda()
-            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())
-    elif not getattr(LanguageModule, "embedding_warning_printed", False):
-        logging.getLogger(__name__).warning(
-            "Distributed processes aren't initialized, so the output layer "
-            "is not initialized with weights from the word embeddings. "
-            "If you are just manipulating a model this is fine, but "
-            "this needs to be handled manually. If you are training "
-            "something is definitely wrong."
-        )
-        LanguageModule.embedding_warning_printed = True


 def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
     if self.num_nextn_predict_layers == 0:
         return (
@@ -317,11 +251,6 @@ def gpt_model_forward(
     loss = 0

     # Multi token prediction module
     if self.num_nextn_predict_layers and self.training:
-        if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
-            output_weight = self.output_layer.weight
-            output_weight.zero_out_wgrad = True
-        embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
-
         mtp_hidden_states = hidden_states
         for i in range(self.num_nextn_predict_layers):
             mtp_hidden_states, mtp_loss = self.mtp_layers[i](
@@ -333,7 +262,8 @@ def gpt_model_forward(
                 inference_params,
                 packed_seq_params,
                 extra_block_kwargs,
-                embeding_weight=embedding_weight,
+                embedding_layer=self.embedding,
+                output_layer=self.output_layer,
                 output_weight=output_weight,
             )
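Taken together, the gpt_model.py changes drop the bespoke shared_embedding_or_mtp_embedding_weight/setup_mtp_embeddings pair and instead override Megatron's standard shared_embedding_or_output_weight, so a last pipeline stage with MTP layers reports the tied input-embedding weight. A pure-Python sketch of that routing decision (function and argument names are illustrative, not the real API):

    from typing import Optional

    def pick_shared_weight(pre_process: bool, post_process: bool,
                           num_nextn_predict_layers: int,
                           embedding_weight, output_weight) -> Optional[object]:
        # Mirrors the patched accessor: with MTP layers configured, the last stage
        # returns the input-embedding weight (now materialised there by the init
        # wrapper); otherwise the usual output projection weight is returned.
        if pre_process or (post_process and num_nextn_predict_layers):
            return embedding_weight
        if post_process:
            return output_weight
        return None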
dcu_megatron/core/transformer/mtp/multi_token_predictor.py

@@ -46,7 +46,6 @@ class MultiTokenPredictor(MegatronModule):
         rotary_percent: float = 1.0,
         rotary_base: int = 10000,
         seq_len_interpolation_factor: Optional[float] = None,
-        share_mtp_embedding_and_output_weight=True,
         recompute_mtp_norm=False,
         recompute_mtp_layer=False,
         add_output_layer_bias=False
@@ -65,20 +64,10 @@ class MultiTokenPredictor(MegatronModule):
         self.parallel_output = parallel_output
         self.position_embedding_type = position_embedding_type
-        # share with main model
-        self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
         self.recompute_layer_norm = recompute_mtp_norm
         self.recompute_mtp_layer = recompute_mtp_layer
         self.add_output_layer_bias = add_output_layer_bias

-        self.embedding = LanguageModelEmbedding(
-            config=self.config,
-            vocab_size=self.vocab_size,
-            max_sequence_length=self.max_sequence_length,
-            position_embedding_type=self.position_embedding_type,
-            skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
-        )
-
         if self.position_embedding_type == 'rope':
             self.rotary_pos_emb = RotaryEmbedding(
                 kv_channels=self.config.kv_channels,
@@ -138,23 +127,6 @@ class MultiTokenPredictor(MegatronModule):
         self.embedding_activation_buffer = None
         self.grad_output_buffer = None

-        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-            column_parallel_linear_impl = FluxColumnParallelLinear
-        else:
-            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
-
-        self.output_layer = column_parallel_linear_impl(
-            self.config.hidden_size,
-            self.vocab_size,
-            config=self.config,
-            init_method=self.config.init_method,
-            bias=False,
-            skip_bias_add=False,
-            gather_output=not self.parallel_output,
-            skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
-            embedding_activation_buffer=self.embedding_activation_buffer,
-            grad_output_buffer=self.grad_output_buffer,
-        )

     def forward(
         self,
         hidden_input_ids: Tensor,
@@ -165,16 +137,16 @@ class MultiTokenPredictor(MegatronModule):
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
-        embeding_weight: Optional[torch.Tensor] = None,
-        output_weight: Optional[torch.Tensor] = None,
+        embedding_layer=None,
+        output_layer=None,
+        output_weight=None,
     ):
         """Forward function of the MTP module"""

         # Decoder embedding.
-        decoder_input = self.embedding(
+        decoder_input = embedding_layer(
             input_ids=embed_input_ids,
             position_ids=position_ids,
-            weight=embeding_weight,
         )

         # Rotary positional embeddings (embedding is None for PP intermediate devices)
@@ -251,7 +223,7 @@ class MultiTokenPredictor(MegatronModule):
         else:
             finalnorm_output = hidden_states

-        logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
+        logits, _ = output_layer(finalnorm_output, weight=output_weight)

         if self.recompute_layer_norm:
             self.finalnorm_ckpt.discard_output()
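With its own LanguageModelEmbedding and output ColumnParallelLinear removed, the predictor now receives both modules per call from the parent GPT model, so the tied parameters exist exactly once. A toy, self-contained illustration of that dependency-injection shape (TinyMTPHead and its arguments are made up for the example, not the real class):

    import torch
    import torch.nn as nn

    class TinyMTPHead(nn.Module):
        """Holds no embedding/output parameters of its own; both come from the caller."""

        def forward(self, hidden, input_ids, embedding_layer=None, output_layer=None):
            decoder_input = embedding_layer(input_ids)   # parent's (tied) input embedding
            mixed = hidden + decoder_input               # stand-in for the real MTP block
            return output_layer(mixed)                   # parent's output projection

    # Usage: the parent model passes its own modules, so nothing is duplicated.
    embed = nn.Embedding(32, 8)
    proj = nn.Linear(8, 32, bias=False)
    head = TinyMTPHead()
    ids = torch.randint(0, 32, (2, 4))
    logits = head(embed(ids), ids, embedding_layer=embed, output_layer=proj)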