evt_fugx1 / dcu_megatron

Commit 43770f8e, authored May 06, 2025 by dongcl
Parent: b85974a6

bug fix
Showing 4 changed files with 54 additions and 13 deletions (+54, -13):
dcu_megatron/adaptor/megatron_adaptor.py         +4   -3
dcu_megatron/core/models/gpt/gpt_layer_specs.py  +10  -8
dcu_megatron/core/models/gpt/gpt_model.py        +34  -0
dcu_megatron/core/tensor_parallel/layers.py      +6   -2
dcu_megatron/adaptor/megatron_adaptor.py

@@ -89,9 +89,12 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass
 
     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
 
         # GPT Model
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+                                    gpt_model_init_wrapper,
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)

@@ -171,8 +174,6 @@ class CoreAdaptation(MegatronAdaptationABC):
             FluxRowParallelLinear
         )
-        MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec", get_gpt_layer_with_flux_spec)
-        MegatronAdaptation.register("megatron.core.tensor_parallel.layers", FluxColumnParallelLinear)
 
     def patch_pipeline_parallel(self):
         pass
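A note for readers unfamiliar with the adaptor layer: MegatronAdaptation itself is not shown in this diff, so the snippet below is only a minimal sketch of what a registry with this register(target, replacement, apply_wrapper=...) shape typically does: resolve the dotted target path, then either swap the attribute outright or, when apply_wrapper=True, treat the replacement as a wrapper factory that receives the original callable (which is how gpt_model_init_wrapper is written in this commit). The PatchRegistry name and all of its behaviour are assumptions for illustration, not the repository's actual implementation.

```python
import importlib


class PatchRegistry:
    """Illustrative stand-in for MegatronAdaptation (assumed behaviour, not the real class)."""

    _patches = []

    @classmethod
    def register(cls, target, replacement, apply_wrapper=False):
        # Record the patch; application is deferred until apply() is called.
        cls._patches.append((target, replacement, apply_wrapper))

    @classmethod
    def apply(cls):
        for target, replacement, apply_wrapper in cls._patches:
            owner, attr = cls._resolve(target)
            if apply_wrapper:
                # Replacement is a wrapper factory: it takes the original callable
                # and returns the patched one (the gpt_model_init_wrapper pattern).
                setattr(owner, attr, replacement(getattr(owner, attr)))
            else:
                setattr(owner, attr, replacement)

    @staticmethod
    def _resolve(target):
        # 'pkg.mod.Cls.method' -> import the longest importable prefix, getattr the rest.
        parts = target.split(".")
        for i in range(len(parts) - 1, 0, -1):
            try:
                obj = importlib.import_module(".".join(parts[:i]))
            except ModuleNotFoundError:
                continue
            for name in parts[i:-1]:
                obj = getattr(obj, name)
            return obj, parts[-1]
        raise ImportError(f"cannot resolve {target!r}")
```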
dcu_megatron/core/models/gpt/gpt_layer_specs.py

@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.torch_norm import L2Norm
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (

@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear
 )
-from dcu_megatron.core.transformer.multi_token_prediction import (
-    MultiTokenPredictionBlockSubmodules,
-    get_mtp_layer_offset,
-    get_mtp_layer_spec,
-    get_mtp_num_layers_to_build,
-)
 
 def get_gpt_layer_with_flux_spec(

@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
     multi_latent_attention: Optional[bool] = False,
     fp8: Optional[str] = None,  # pylint: disable=unused-arguments
     moe_use_legacy_grouped_gemm: Optional[bool] = False,
+    qk_l2_norm: Optional[bool] = False,
 ) -> ModuleSpec:
     """Use this spec to use flux modules (required for fp8 training).

@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
         fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
         moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
             Defaults to False.
+        qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
 
     Returns:
         ModuleSpec: Module specification with flux modules

@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
     )
 
     if multi_latent_attention:
+        assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(

@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
                         linear_qkv=FluxColumnParallelLinear,
                         core_attention=TEDotProductAttention,
                         linear_proj=FluxRowParallelLinear,
-                        q_layernorm=qk_norm if qk_layernorm else IdentityOp,
-                        k_layernorm=qk_norm if qk_layernorm else IdentityOp,
+                        q_layernorm=(
+                            L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                        ),
+                        k_layernorm=(
+                            L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                        ),
                     ),
                 ),
                 self_attn_bda=get_bias_dropout_add,
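The substantive change above is the q_layernorm/k_layernorm selection: qk_l2_norm takes precedence, then qk_layernorm, then IdentityOp. Pulled out as a standalone helper the precedence reads as below; pick_qk_norm is a hypothetical name used only for illustration, with plain strings standing in for the L2Norm, qk_norm, and IdentityOp classes.

```python
def pick_qk_norm(qk_l2_norm, qk_layernorm, l2_norm_cls, layernorm_cls, identity_cls):
    """qk_l2_norm wins; otherwise use the configured q/k layer norm, else the identity op."""
    if qk_l2_norm:
        return l2_norm_cls
    return layernorm_cls if qk_layernorm else identity_cls


# With qk_l2_norm=True both q_layernorm and k_layernorm resolve to the L2 norm class;
# the multi-latent-attention branch instead asserts qk_l2_norm is False because it is
# not supported there.
assert pick_qk_norm(True, True, "L2Norm", "qk_norm", "IdentityOp") == "L2Norm"
assert pick_qk_norm(False, True, "L2Norm", "qk_norm", "IdentityOp") == "qk_norm"
assert pick_qk_norm(False, False, "L2Norm", "qk_norm", "IdentityOp") == "IdentityOp"
```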
dcu_megatron/core/models/gpt/gpt_model.py

@@ -2,14 +2,48 @@ from collections import OrderedDict
 from typing import Optional
+from functools import wraps
+import os
 
 import torch
 from torch import Tensor
 
 from megatron.core import tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import WrappedTensor, deprecate_inference_params
 
+from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
+
+
+def gpt_model_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # Output
+        if ((self.post_process or self.mtp_process)
+                and int(os.getenv("USE_FLUX_OVERLAP", "0"))):
+            self.output_layer = FluxColumnParallelLinear(
+                self.config.hidden_size,
+                self.vocab_size,
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=False,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=self.pre_process
+                and self.share_embeddings_and_output_weights,
+                embedding_activation_buffer=self.embedding_activation_buffer,
+                grad_output_buffer=self.grad_output_buffer,
+            )
+
+            if self.pre_process or self.post_process:
+                self.setup_embeddings_and_output_layer()
+
+    return wrapper
+
 
 def gpt_model_forward(
     self,
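gpt_model_init_wrapper follows a common post-__init__ patching pattern: let the original constructor run, then, when USE_FLUX_OVERLAP is set to a non-zero value on a rank that owns the output layer (post_process or mtp_process), rebuild output_layer as a FluxColumnParallelLinear and re-run setup_embeddings_and_output_layer(). Below is a toy, self-contained version of that control flow; DummyModel and the placeholder strings are illustrative only and not part of the repository.

```python
import os
from functools import wraps


def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)  # the original __init__ runs unchanged first
        # Opt-in via environment variable, mirroring int(os.getenv("USE_FLUX_OVERLAP", "0")).
        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            self.output_layer = "flux_column_parallel_linear"  # stand-in for the real layer
    return wrapper


class DummyModel:
    def __init__(self):
        self.output_layer = "plain_column_parallel_linear"


# Roughly what registering the wrapper against GPTModel.__init__ is presumed to do.
DummyModel.__init__ = init_wrapper(DummyModel.__init__)

os.environ["USE_FLUX_OVERLAP"] = "1"
assert DummyModel().output_layer == "flux_column_parallel_linear"
```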
dcu_megatron/core/tensor_parallel/layers.py

@@ -24,7 +24,7 @@ from megatron.core.tensor_parallel.mappings import (
 )
 from megatron.core.tensor_parallel import (
     ColumnParallelLinear,
-    RowParallelLinear,
+    RowParallelLinear
 )
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,

@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
         disable_grad_reduce: bool = False,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxColumnParallelLinear, self).__init__(
             input_size=input_size,

@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             is_expert=is_expert,
             tp_comm_buffer_name=tp_comm_buffer_name,
             disable_grad_reduce=disable_grad_reduce,
+            tp_group=tp_group,
         )
 
         # flux params

@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
         keep_master_weight_for_test: bool = False,
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxRowParallelLinear, self).__init__(

@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
             stride=stride,
             keep_master_weight_for_test=keep_master_weight_for_test,
             is_expert=is_expert,
-            tp_comm_buffer_name=tp_comm_buffer_name
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
         )
 
         # flux params
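The layers.py changes are constructor plumbing: both Flux wrappers gain a tp_group keyword and forward it to the parent __init__ (the FluxRowParallelLinear super() call previously stopped at tp_comm_buffer_name). Below is a toy version of that forwarding pattern; Base and Derived are hypothetical stand-ins for the Megatron parallel linear classes and their Flux subclasses.

```python
from typing import Optional


class Base:
    """Stand-in for RowParallelLinear, which already accepts an optional tp_group."""

    def __init__(self, input_size: int, tp_group: Optional[object] = None):
        self.input_size = input_size
        self.tp_group = tp_group


class Derived(Base):
    """Stand-in for FluxRowParallelLinear after the fix: the new keyword is forwarded."""

    def __init__(self, input_size: int, tp_group: Optional[object] = None):
        super().__init__(
            input_size=input_size,
            tp_group=tp_group,  # forwarded; dropping this line would leave the parent default
        )


assert Derived(8, tp_group="dummy-group").tp_group == "dummy-group"
```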