Commit be9a69d7 authored by dongcl

Integrate flux

parent 0b2b5417
@@ -112,7 +112,7 @@ class CoreAdaptation(MegatronAdaptationABC):
 
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper, transformer_block_forward
-        from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
+        from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
 
         # Transformer block
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
@@ -122,9 +122,9 @@ class CoreAdaptation(MegatronAdaptationABC):
 
         # Transformer config
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
-                                    TransformerConfig)
+                                    TransformerConfigPatch)
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
-                                    MLATransformerConfig)
+                                    MLATransformerConfigPatch)
 
         # Moe
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
@@ -153,8 +153,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
 
     def patch_tensor_parallel(self):
-        from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
 
         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
@@ -170,6 +171,19 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
 
+        # flux
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                    ColumnParallelLinearPatch.forward)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+                                    RowParallelLinearPatch.forward)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
...
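
The registrations above swap ColumnParallelLinear/RowParallelLinear.__init__ for parallel_linear_init_wrapper (applied as a wrapper) and their forward for ColumnParallelLinearPatch.forward / RowParallelLinearPatch.forward; the actual implementations live in files collapsed out of this view. Below is a minimal sketch of that pattern only, not the commit's code: FLUX_DEFAULTS, _flux_linear, and the nn.Linear base are hypothetical stand-ins for reading the global args and calling the flux fused kernel.

# Illustrative sketch only -- not the code from this commit.
import functools

import torch

# Stand-in for the --use-flux / --flux-transpose-weight switches added further down.
FLUX_DEFAULTS = {"use_flux": False, "flux_transpose_weight": False}


def parallel_linear_init_wrapper(init_fn):
    """Run the original __init__, then stash the flux switches on the module
    (mirrors the apply_wrapper=True registrations above)."""
    @functools.wraps(init_fn)
    def wrapper(self, *args, **kwargs):
        init_fn(self, *args, **kwargs)
        self.use_flux = FLUX_DEFAULTS["use_flux"]
        self.flux_transpose_weight = FLUX_DEFAULTS["flux_transpose_weight"]
    return wrapper


def _flux_linear(input_, weight, transpose_weight):
    # Placeholder for a flux fused GEMM + tensor-parallel communication call;
    # a plain matmul keeps the sketch self-contained. transpose_weight marks a
    # weight already stored in the layout the fused kernel expects.
    w = weight if transpose_weight else weight.t()
    return torch.matmul(input_, w)


class ColumnParallelLinearPatch(torch.nn.Linear):
    """Patched forward: dispatch to flux when enabled, else the original path."""

    def forward(self, input_):
        if getattr(self, "use_flux", False):
            return _flux_linear(input_, self.weight,
                                getattr(self, "flux_transpose_weight", False))
        return super().forward(input_)


# How the registrations are expected to compose, in miniature:
FLUX_DEFAULTS["use_flux"] = True                      # pretend --use-flux was set
ColumnParallelLinearPatch.__init__ = parallel_linear_init_wrapper(
    ColumnParallelLinearPatch.__init__)
layer = ColumnParallelLinearPatch(16, 32, bias=False)
out = layer(torch.randn(4, 16))                       # takes the flux branch
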
-from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
 from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
+from .layers import (
+    parallel_linear_init_wrapper,
+    ColumnParallelLinearPatch,
+    RowParallelLinearPatch,
+    vocab_parallel_embedding_forward,
+    vocab_parallel_embedding_init,
+)
\ No newline at end of file
@@ -525,4 +525,13 @@ def _add_mtp_args(parser):
                        help='Multi-Token prediction recompute layer')
     group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                        help='Main model share embedding and output weight with mtp layer.')
-    return parser
\ No newline at end of file
+    return parser
+
+
+def _add_flux_args(parser):
+    group = parser.add_argument_group(title='flux')
+    group.add_argument('--use-flux', action='store_true', default=False,
+                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
+    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
+                       help='Whether to transpose weight when using flux kernel')
+    return parser
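
A hedged usage sketch for the new flags, assuming _add_flux_args is chained onto the argument parser the same way the other _add_*_args helpers are; the call site is not part of this hunk.

import argparse

# Hypothetical wiring; only _add_flux_args itself is added in this commit hunk.
parser = argparse.ArgumentParser()
parser = _add_flux_args(parser)

# e.g. a launch command would append: --use-flux --flux-transpose-weight
args = parser.parse_args(['--use-flux', '--flux-transpose-weight'])
assert args.use_flux and args.flux_transpose_weight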