be9a69d7
Commit
be9a69d7
authored
Apr 10, 2025
by
dongcl
Browse files
集成flux
parent
0b2b5417
Showing 6 changed files, with 893 additions and 913 deletions:
dcu_megatron/adaptor/megatron_adaptor.py              +18    -4
dcu_megatron/core/__init__.py                          +0    -1
dcu_megatron/core/tensor_parallel/__init__.py          +7    -0
dcu_megatron/core/tensor_parallel/layers.py          +847    -2
dcu_megatron/core/transformer/transformer_config.py   +11  -905
dcu_megatron/training/arguments.py                    +10    -1
dcu_megatron/adaptor/megatron_adaptor.py

@@ -112,7 +112,7 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper, transformer_block_forward
-        from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
+        from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

         # Transformer block
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',

@@ -122,9 +122,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         # Transformer config
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
-                                    TransformerConfig)
+                                    TransformerConfigPatch)
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
-                                    MLATransformerConfig)
+                                    MLATransformerConfigPatch)

         # Moe
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',

@@ -153,8 +153,9 @@ class CoreAdaptation(MegatronAdaptationABC):
             TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)

     def patch_tensor_parallel(self):
-        from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',

@@ -170,6 +171,19 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)

+        # flux
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                    ColumnParallelLinearPatch.forward)
+
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+                                    RowParallelLinearPatch.forward)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
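Every patch in this adaptor goes through MegatronAdaptation.register, whose implementation is not part of this commit. Below is a minimal sketch of what such a registry could look like, assuming it resolves the dotted path with importlib and then monkey-patches the attribute; the class body and the exact meaning of apply_wrapper=True are inferred from the call sites above, not taken from the repository.

import importlib

class AdaptationRegistrySketch:
    """Hypothetical stand-in for MegatronAdaptation: maps dotted paths to replacements."""

    _patches = {}

    @classmethod
    def register(cls, orig_path, replacement, apply_wrapper=False):
        # With apply_wrapper=True the replacement is treated as a wrapper that
        # receives the original callable (inferred from parallel_linear_init_wrapper).
        cls._patches[orig_path] = (replacement, apply_wrapper)

    @classmethod
    def apply(cls):
        for path, (replacement, apply_wrapper) in cls._patches.items():
            module_path, attr = path.rsplit('.', 1)
            parts = module_path.split('.')
            obj = None
            # Import the longest importable prefix, then descend attributes
            # (handles targets like 'pkg.module.Class.method').
            for i in range(len(parts), 0, -1):
                try:
                    obj = importlib.import_module('.'.join(parts[:i]))
                    break
                except ModuleNotFoundError:
                    continue
            for name in parts[i:]:
                obj = getattr(obj, name)
            original = getattr(obj, attr)
            setattr(obj, attr, replacement(original) if apply_wrapper else replacement)

Under a registry of this shape, ColumnParallelLinearPatch.forward and RowParallelLinearPatch.forward simply replace the upstream forward methods, while parallel_linear_init_wrapper (registered with apply_wrapper=True) receives the original __init__ and returns a wrapped version of it.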
dcu_megatron/core/__init__.py

-from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
 from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
dcu_megatron/core/tensor_parallel/__init__.py  (new file)

+from .layers import (
+    parallel_linear_init_wrapper,
+    ColumnParallelLinearPatch,
+    RowParallelLinearPatch,
+    vocab_parallel_embedding_forward,
+    vocab_parallel_embedding_init,
+)
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py

This diff is collapsed in the commit view (+847, -2) and is not shown here.
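The 847 lines added to this file are where ColumnParallelLinearPatch, RowParallelLinearPatch and parallel_linear_init_wrapper live, i.e. the flux-backed forward paths and the init hook registered above with apply_wrapper=True. Since the contents are collapsed, the following is only a hypothetical illustration of the init-wrapper pattern; the attribute names mirror the new CLI flags and are an assumption, not the actual code.

import functools

def parallel_linear_init_wrapper(orig_init):
    # Hypothetical: run the original __init__, then record whether this layer
    # should route its GEMM + communication through flux.
    @functools.wraps(orig_init)
    def wrapper(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        config = kwargs.get('config', None)
        # Flag names assumed to mirror --use-flux / --flux-transpose-weight.
        self.use_flux = bool(getattr(config, 'use_flux', False))
        self.flux_transpose_weight = bool(getattr(config, 'flux_transpose_weight', False))
    return wrapper

A patched forward would presumably branch on such a flag and fall back to the stock Megatron path when it is off, but the real logic lives in the collapsed file.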
dcu_megatron/core/transformer/transformer_config.py

This diff is collapsed in the commit view (+11, -905) and is not shown here.
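According to the summary above, this file loses 905 lines and gains 11, and megatron_adaptor.py now imports TransformerConfigPatch and MLATransformerConfigPatch from it. The collapsed contents are not shown; one plausible shape for such a patch, assuming it merely extends the upstream dataclasses with the new flux options (field names mirror the CLI flags and are an assumption), is:

from dataclasses import dataclass

from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig


@dataclass
class TransformerConfigPatch(TransformerConfig):
    # Assumed fields mirroring --use-flux / --flux-transpose-weight.
    use_flux: bool = False
    flux_transpose_weight: bool = False


@dataclass
class MLATransformerConfigPatch(MLATransformerConfig):
    use_flux: bool = False
    flux_transpose_weight: bool = False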
dcu_megatron/training/arguments.py

@@ -525,4 +525,13 @@ def _add_mtp_args(parser):
                        help='Multi-Token prediction recompute layer')
     group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                        help='Main model share embedding and output weight with mtp layer.')
-    return parser
\ No newline at end of file
+    return parser
+
+def _add_flux_args(parser):
+    group = parser.add_argument_group(title='multi token prediction')
+
+    group.add_argument('--use-flux', action='store_true', default=False,
+                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
+    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
+                       help='Whether to transpose weight when using flux kernel')
+    return parser
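The new _add_flux_args group defines two store_true flags. Note that its argument-group title still reads 'multi token prediction', apparently carried over from _add_mtp_args, and the hunk does not show where the helper is hooked into Megatron's argument parsing. As a self-contained check of how the flags parse, here is the same pair of add_argument calls on a plain argparse parser (the group title 'flux' below is illustrative, not what the diff uses):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='flux')
group.add_argument('--use-flux', action='store_true', default=False,
                   help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                   help='Whether to transpose weight when using flux kernel')

args = parser.parse_args(['--use-flux', '--flux-transpose-weight'])
print(args.use_flux, args.flux_transpose_weight)  # True True

Presumably these flags end up on the patched config objects (for example use_flux on TransformerConfigPatch) so the patched ColumnParallelLinear and RowParallelLinear can decide at construction time whether to route through flux.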