evt_fugx1 / dcu_megatron · Commits · 89d29a02

Commit 89d29a02, authored Apr 25, 2025 by silencealiang
Parent: 81e19772

    bug fix

Showing 3 changed files with 133 additions and 43 deletions:

    dcu_megatron/adaptor/megatron_adaptor.py                        +5    -7
    dcu_megatron/core/models/gpt/gpt_model.py                       +127  -35
    dcu_megatron/core/transformer/mtp/multi_token_predictor.py      +1    -1
dcu_megatron/adaptor/megatron_adaptor.py  (+5 / -7)

@@ -99,7 +99,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         )
         from ..core.models.gpt.gpt_model import (
             gpt_model_forward,
-            gpt_model_init_wrapper,
+            gpt_model_init,
             shared_embedding_or_output_weight,
         )
         from ..core.models.common.language_module.language_module import (
@@ -130,9 +130,7 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
                                     shared_embedding_or_output_weight)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init_wrapper, apply_wrapper=True)
+                                    gpt_model_init)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
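The two registration styles above differ in how the second argument is treated. MegatronAdaptation's implementation is not part of this commit, so the following is only a hypothetical sketch of the apparent contract (a toy class with (owner, name) lookup instead of the real dotted-path strings): with apply_wrapper=True the registered object is called with the original attribute and must return the patched version, while a plain registration replaces the attribute directly, which is what the switch to gpt_model_init relies on.

    # Hypothetical stand-in for MegatronAdaptation, not the real class.
    from functools import wraps


    class ToyAdaptation:
        _registry = []

        @classmethod
        def register(cls, owner, name, replacement, apply_wrapper=False):
            cls._registry.append((owner, name, replacement, apply_wrapper))

        @classmethod
        def apply(cls):
            for owner, name, replacement, apply_wrapper in cls._registry:
                original = getattr(owner, name)
                # apply_wrapper=True: call the factory with the original, install the result.
                patched = replacement(original) if apply_wrapper else replacement
                setattr(owner, name, patched)


    class Model:
        def forward(self, x):
            return x + 1

        def loss(self, x):
            return x * x


    def traced_forward_wrapper(fn):          # wrapper factory: receives the original method
        @wraps(fn)
        def wrapper(self, x):
            print("patched forward called")
            return fn(self, x)
        return wrapper


    def new_loss(self, x):                   # direct replacement: the original never runs
        return abs(x)


    ToyAdaptation.register(Model, "forward", traced_forward_wrapper, apply_wrapper=True)
    ToyAdaptation.register(Model, "loss", new_loss)
    ToyAdaptation.apply()

    m = Model()
    assert m.forward(1) == 2                 # original logic, wrapped
    assert m.loss(-3) == 3                   # replaced outright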
@@ -152,9 +150,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-                                    apply_wrapper=True)
+        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
+        #                             apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
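These registrations pass torch.compile(...) without a function. When torch.compile is called with only configuration arguments it returns a decorator, so under apply_wrapper=True the registry can hand it the original MoE utility and install the compiled result. A standalone illustration with a stand-in function (permute_like is hypothetical, not Megatron's permute):

    import torch


    def permute_like(tokens: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # Toy reordering op standing in for the registered MoE utility.
        return tokens.index_select(0, indices)


    # torch.compile with no function argument returns a decorator.
    compile_wrapper = torch.compile(mode='max-autotune-no-cudagraphs')
    compiled_permute = compile_wrapper(permute_like)

    x = torch.randn(4, 8)
    idx = torch.tensor([3, 1, 2, 0])
    assert torch.allclose(compiled_permute(x, idx), permute_like(x, idx))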
dcu_megatron/core/models/gpt/gpt_model.py  (+127 / -35)

@@ -25,45 +25,131 @@ from dcu_megatron.core.transformer.transformer_config import TransformerConfig
 from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear


-def gpt_model_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        fn(self, *args, **kwargs)
-        if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
-            self.embedding = LanguageModelEmbedding(
-                config=self.config,
-                vocab_size=self.vocab_size,
-                max_sequence_length=self.max_sequence_length,
-                position_embedding_type=kwargs.get("position_embedding_type"),
-                scatter_to_sequence_parallel=kwargs.get("scatter_embedding_sequence_parallel"),
-            )
-        if (
-            self.post_process
-            and int(os.getenv("USE_FLUX_OVERLAP", "0"))
-        ):
-            self.output_layer = FluxColumnParallelLinear(
-                self.config.hidden_size,
-                self.vocab_size,
-                config=self.config,
-                init_method=self.config.init_method,
-                bias=False,
-                skip_bias_add=False,
-                gather_output=not self.parallel_output,
-                skip_weight_param_allocation=self.pre_process
-                and self.share_embeddings_and_output_weights,
-                embedding_activation_buffer=self.embedding_activation_buffer,
-                grad_output_buffer=self.grad_output_buffer,
-            )
-            self.setup_embeddings_and_output_layer()
+def gpt_model_init(
+    self,
+    config: TransformerConfig,
+    transformer_layer_spec: ModuleSpec,
+    vocab_size: int,
+    max_sequence_length: int,
+    pre_process: bool = True,
+    post_process: bool = True,
+    fp16_lm_cross_entropy: bool = False,
+    parallel_output: bool = True,
+    share_embeddings_and_output_weights: bool = False,
+    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
+    rotary_percent: float = 1.0,
+    rotary_base: int = 10000,
+    rope_scaling: bool = False,
+    rope_scaling_factor: float = 8.0,
+    scatter_embedding_sequence_parallel: bool = True,
+    seq_len_interpolation_factor: Optional[float] = None,
+) -> None:
+    super(GPTModel, self).__init__(config=config)
+
+    if has_config_logger_enabled(config):
+        log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
+    self.vocab_size = vocab_size
+    self.max_sequence_length = max_sequence_length
+    self.pre_process = pre_process
+    self.post_process = post_process
+    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+    self.parallel_output = parallel_output
+    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+    self.position_embedding_type = position_embedding_type
+
+    # megatron core pipelining currently depends on model type
+    # TODO: remove this dependency ?
+    self.model_type = ModelType.encoder_or_decoder
+
+    # These 4 attributes are needed for TensorRT-LLM export.
+    self.max_position_embeddings = max_sequence_length
+    self.rotary_percent = rotary_percent
+    self.rotary_base = rotary_base
+    self.rotary_scaling = rope_scaling
+
+    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
+
+    if self.pre_process:
+        self.embedding = LanguageModelEmbedding(
+            config=self.config,
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.max_sequence_length,
+            position_embedding_type=position_embedding_type,
+            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
+        )
+
+    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
+        self.rotary_pos_emb = RotaryEmbedding(
+            kv_channels=self.config.kv_channels,
+            rotary_percent=rotary_percent,
+            rotary_interleaved=self.config.rotary_interleaved,
+            seq_len_interpolation_factor=seq_len_interpolation_factor,
+            rotary_base=rotary_base,
+            rope_scaling=rope_scaling,
+            rope_scaling_factor=rope_scaling_factor,
+            use_cpu_initialization=self.config.use_cpu_initialization,
+        )
+
+    # Cache for RoPE tensors which do not change between iterations.
+    self.rotary_pos_emb_cache = {}
+
+    # Transformer.
+    self.decoder = TransformerBlock(
+        config=self.config,
+        spec=transformer_layer_spec,
+        pre_process=self.pre_process,
+        post_process=self.post_process,
+    )
+
+    if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
+        self.embedding = LanguageModelEmbedding(
+            config=self.config,
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.max_sequence_length,
+            position_embedding_type=position_embedding_type,
+            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
+        )
+
+    # Output
+    if post_process:
+        if self.config.defer_embedding_wgrad_compute:
+            # The embedding activation buffer preserves a reference to the input activations
+            # of the final embedding projection layer GEMM. It will hold the activations for
+            # all the micro-batches of a global batch for the last pipeline stage. Once we are
+            # done with all the back props for all the microbatches for the last pipeline stage,
+            # it will be in the pipeline flush stage. During this pipeline flush we use the
+            # input activations stored in embedding activation buffer and gradient outputs
+            # stored in gradient buffer to calculate the weight gradients for the embedding
+            # final linear layer.
+            self.embedding_activation_buffer = []
+            self.grad_output_buffer = []
+        else:
+            self.embedding_activation_buffer = None
+            self.grad_output_buffer = None
+
+        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+            column_parallel_linear_impl = FluxColumnParallelLinear
+        else:
+            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
+
+        self.output_layer = column_parallel_linear_impl(
+            config.hidden_size,
+            self.vocab_size,
+            config=config,
+            init_method=config.init_method,
+            bias=False,
+            skip_bias_add=False,
+            gather_output=not self.parallel_output,
+            skip_weight_param_allocation=self.pre_process
+            and self.share_embeddings_and_output_weights,
+            embedding_activation_buffer=self.embedding_activation_buffer,
+            grad_output_buffer=self.grad_output_buffer,
+        )

         # add mtp
-        self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
         if self.num_nextn_predict_layers:
             assert hasattr(self.config, "mtp_spec")
-            self.mtp_spec: ModuleSpec = self.config.mtp_spec
+            self.mtp_spec = self.config.mtp_spec
             self.recompute_mtp_norm = self.config.recompute_mtp_norm
             self.recompute_mtp_layer = self.config.recompute_mtp_layer
             self.mtp_loss_scale = self.config.mtp_loss_scale
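The output-projection class in the new gpt_model_init is chosen from the USE_FLUX_OVERLAP environment variable. A standalone sketch of that toggle with stand-in classes (the *Stub names are hypothetical, not the project's classes):

    import os


    class ColumnParallelLinearStub:           # stand-in for tensor_parallel.ColumnParallelLinear
        pass


    class FluxColumnParallelLinearStub:       # stand-in for FluxColumnParallelLinear
        pass


    def pick_output_layer_impl():
        # int("0") is falsy and int("1") is truthy; a non-numeric value such as "true"
        # would raise ValueError, so the flag has to be set to "0" or "1".
        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            return FluxColumnParallelLinearStub
        return ColumnParallelLinearStub


    os.environ["USE_FLUX_OVERLAP"] = "1"
    assert pick_output_layer_impl() is FluxColumnParallelLinearStub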
@@ -81,7 +167,7 @@ def gpt_model_init_wrapper(fn):
             parallel_output=self.parallel_output,
             position_embedding_type=self.position_embedding_type,
             rotary_percent=self.rotary_percent,
-            seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
+            seq_len_interpolation_factor=seq_len_interpolation_factor,
             recompute_mtp_norm=self.recompute_mtp_norm,
             recompute_mtp_layer=self.recompute_mtp_layer,
             add_output_layer_bias=False
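The commit message only says "bug fix", but one plausible reading of this change is the usual *args/**kwargs pitfall: kwargs.get only sees values that the caller actually passed by keyword, whereas the rewritten init reads the named parameter directly. A minimal illustration with a hypothetical wrapper (not the project's code), so whether this was the actual bug is an assumption:

    def wrapped_init(*args, **kwargs):
        # Only keyword arguments land in kwargs; positional ones are invisible here.
        return kwargs.get("seq_len_interpolation_factor", None)


    assert wrapped_init(config="cfg", seq_len_interpolation_factor=4.0) == 4.0
    assert wrapped_init("cfg", 4.0) is None   # passed positionally -> silently lost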
@@ -90,7 +176,13 @@ def gpt_model_init_wrapper(fn):
                 ]
             )
-    return wrapper
+    if self.pre_process or self.post_process:
+        self.setup_embeddings_and_output_layer()
+
+    if has_config_logger_enabled(self.config):
+        log_config_to_disk(self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt')


 def shared_embedding_or_output_weight(self) -> Tensor:
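With the wrapper gone, the adaptor now installs gpt_model_init directly over GPTModel.__init__. A toy example of why assigning a module-level function whose first parameter is self over a class method works as a drop-in replacement (the class and names below are hypothetical, and this is presumably what the register('...GPTModel.__init__', gpt_model_init) call amounts to):

    class Original:
        def __init__(self, value):
            self.value = value


    def patched_init(self, value):
        # Replacement logic; the original __init__ never runs.
        self.value = value * 2


    Original.__init__ = patched_init
    assert Original(3).value == 6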
dcu_megatron/core/transformer/mtp/multi_token_predictor.py  (+1 / -1)

@@ -144,7 +144,7 @@ class MultiTokenPredictor(MegatronModule):
         """Forward function of the MTP module"""

         # Decoder embedding.
-        decoder_input = embedding(
+        decoder_input = embedding_layer(
             input_ids=embed_input_ids,
             position_ids=position_ids,
         )
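The fix changes the call target from embedding to embedding_layer; the keyword calling convention stays the same. A minimal stand-in module (not the real LanguageModelEmbedding) showing that convention:

    import torch
    import torch.nn as nn


    class ToyEmbedding(nn.Module):
        def __init__(self, vocab_size: int, max_len: int, hidden: int):
            super().__init__()
            self.word = nn.Embedding(vocab_size, hidden)
            self.pos = nn.Embedding(max_len, hidden)

        def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
            # Sum of token and position embeddings, as a stand-in for the real module.
            return self.word(input_ids) + self.pos(position_ids)


    embedding_layer = ToyEmbedding(vocab_size=100, max_len=16, hidden=8)
    embed_input_ids = torch.randint(0, 100, (2, 16))
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
    decoder_input = embedding_layer(input_ids=embed_input_ids, position_ids=position_ids)
    assert decoder_input.shape == (2, 16, 8)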