evt_fugx1 / dcu_megatron / Commits

Commit 43770f8e
authored May 06, 2025 by dongcl
parent b85974a6

bug fix

Showing 4 changed files with 54 additions and 13 deletions (+54 -13)
dcu_megatron/adaptor/megatron_adaptor.py         +4   -3
dcu_megatron/core/models/gpt/gpt_layer_specs.py  +10  -8
dcu_megatron/core/models/gpt/gpt_model.py        +34  -0
dcu_megatron/core/tensor_parallel/layers.py      +6   -2
dcu_megatron/adaptor/megatron_adaptor.py

@@ -89,9 +89,12 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass

     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward

         # GPT Model
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+                                    gpt_model_init_wrapper,
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
                                     gpt_model_forward)

@@ -171,8 +174,6 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     FluxRowParallelLinear)
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)
-        MegatronAdaptation.register("megatron.core.tensor_parallel.layers",
-                                    FluxColumnParallelLinear)

     def patch_pipeline_parallel(self):
         pass
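The registration above wires gpt_model_init_wrapper around GPTModel.__init__ with apply_wrapper=True; the wrapper itself is defined in the gpt_model.py change further down. As a point of reference only, the decorator pattern it relies on looks like the minimal sketch below. The names here are illustrative, and the exact semantics of apply_wrapper=True belong to the MegatronAdaptation helper, which this commit does not show.

    from functools import wraps

    def example_init_wrapper(fn):
        # Same shape as gpt_model_init_wrapper: run the original __init__ first,
        # then adjust attributes the constructor has already created.
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            fn(self, *args, **kwargs)
            # ...post-init patching goes here (e.g. swapping self.output_layer)
        return wrapper

    # Presumed effect of apply_wrapper=True (assumption, not shown in this commit):
    # GPTModel.__init__ = example_init_wrapper(GPTModel.__init__)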
dcu_megatron/core/models/gpt/gpt_layer_specs.py

@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.torch_norm import L2Norm
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (

@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear
 )
-from dcu_megatron.core.transformer.multi_token_prediction import (
-    MultiTokenPredictionBlockSubmodules,
-    get_mtp_layer_offset,
-    get_mtp_layer_spec,
-    get_mtp_num_layers_to_build,
-)

 def get_gpt_layer_with_flux_spec(

@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
     multi_latent_attention: Optional[bool] = False,
     fp8: Optional[str] = None,  # pylint: disable=unused-arguments
     moe_use_legacy_grouped_gemm: Optional[bool] = False,
+    qk_l2_norm: Optional[bool] = False,
 ) -> ModuleSpec:
     """Use this spec to use flux modules (required for fp8 training).

@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
         fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
         moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
             Defaults to False.
+        qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.

     Returns:
         ModuleSpec: Module specification with flux modules

@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
         )

     if multi_latent_attention:
+        assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(

@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
                     linear_qkv=FluxColumnParallelLinear,
                     core_attention=TEDotProductAttention,
                     linear_proj=FluxRowParallelLinear,
-                    q_layernorm=qk_norm if qk_layernorm else IdentityOp,
-                    k_layernorm=qk_norm if qk_layernorm else IdentityOp,
+                    q_layernorm=(
+                        L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                    ),
+                    k_layernorm=(
+                        L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                    ),
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
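The new qk_l2_norm flag switches the q_layernorm/k_layernorm submodules to L2Norm instead of the qk_layernorm-controlled norm. A hypothetical call site, using only keywords that appear in the signature above (the import path is inferred from this file's location):

    from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

    # Build a spec that L2-normalises queries/keys. qk_l2_norm must stay False when
    # multi_latent_attention is enabled, per the assert added in this commit.
    layer_spec = get_gpt_layer_with_flux_spec(
        multi_latent_attention=False,
        qk_l2_norm=True,
    )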
dcu_megatron/core/models/gpt/gpt_model.py

@@ -2,14 +2,48 @@ from collections import OrderedDict
 from typing import Optional
 from functools import wraps
+import os

 import torch
 from torch import Tensor

+from megatron.core import tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import WrappedTensor, deprecate_inference_params

+from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
+
+
+def gpt_model_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # Output
+        if ((self.post_process or self.mtp_process)
+                and int(os.getenv("USE_FLUX_OVERLAP", "0"))):
+            self.output_layer = FluxColumnParallelLinear(
+                self.config.hidden_size,
+                self.vocab_size,
+                config=self.config,
+                init_method=self.config.init_method,
+                bias=False,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=self.pre_process
+                and self.share_embeddings_and_output_weights,
+                embedding_activation_buffer=self.embedding_activation_buffer,
+                grad_output_buffer=self.grad_output_buffer,
+            )
+
+            if self.pre_process or self.post_process:
+                self.setup_embeddings_and_output_layer()
+
+    return wrapper
+
+
 def gpt_model_forward(
     self,
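The wrapper only swaps in the Flux output layer when the USE_FLUX_OVERLAP environment variable parses to a non-zero integer; with the default of "0" it leaves GPTModel untouched. A minimal sketch of enabling it before the model is built (setting the variable in the launch environment, e.g. export USE_FLUX_OVERLAP=1, works the same way):

    import os

    # Must be set before GPTModel.__init__ runs, since the wrapper reads it at init time.
    os.environ["USE_FLUX_OVERLAP"] = "1"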
dcu_megatron/core/tensor_parallel/layers.py

@@ -24,7 +24,7 @@ from megatron.core.tensor_parallel.mappings import (
 )
 from megatron.core.tensor_parallel import (
     ColumnParallelLinear,
-    RowParallelLinear,
+    RowParallelLinear
 )
 from megatron.core.tensor_parallel.layers import (
     custom_fwd,

@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
         disable_grad_reduce: bool = False,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxColumnParallelLinear, self).__init__(
             input_size=input_size,

@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             is_expert=is_expert,
             tp_comm_buffer_name=tp_comm_buffer_name,
             disable_grad_reduce=disable_grad_reduce,
+            tp_group=tp_group,
         )

         # flux params

@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
         keep_master_weight_for_test: bool = False,
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxRowParallelLinear, self).__init__(

@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
             stride=stride,
             keep_master_weight_for_test=keep_master_weight_for_test,
             is_expert=is_expert,
-            tp_comm_buffer_name=tp_comm_buffer_name
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
         )

         # flux params
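Both Flux layers now accept a tp_group process group and forward it verbatim to the upstream Megatron ColumnParallelLinear/RowParallelLinear constructors, matching newer Megatron-LM signatures that take an explicit tp_group. A generic sketch of the pass-through pattern (the classes below are placeholders, not the real Megatron signatures); passing None presumably falls back to the default tensor-parallel group upstream:

    from typing import Optional
    import torch

    class _Base:
        def __init__(self, size: int,
                     tp_group: Optional[torch.distributed.ProcessGroup] = None):
            self.size = size
            self.tp_group = tp_group  # the base class decides which group to communicate over

    class _Derived(_Base):
        def __init__(self, size: int,
                     tp_group: Optional[torch.distributed.ProcessGroup] = None):
            # Same pass-through as FluxColumnParallelLinear / FluxRowParallelLinear.
            super().__init__(size=size, tp_group=tp_group)

    layer = _Derived(4)  # tp_group=None: assumed to mean "use the default group"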