Commit 9eb8683b authored by dongcl

Merge branch 'main' into megatron_v0.11.0

parents 6f016785 be9a69d7
@@ -123,7 +123,7 @@ class CoreAdaptation(MegatronAdaptationABC):
 
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper, transformer_block_forward
-        from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
+        from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
 
         # Transformer block
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
@@ -133,9 +133,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         # Transformer config
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
-                                    TransformerConfig)
+                                    TransformerConfigPatch)
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
-                                    MLATransformerConfig)
+                                    MLATransformerConfigPatch)
 
         # Moe
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
@@ -154,18 +154,19 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_extentions(self):
         import transformer_engine as te
-        from ..core.extensions.transformer_engine import te_dot_product_attention_init
+        from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
         from megatron.core.extensions.transformer_engine import TEGroupedLinear
 
         MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
-                                    te_dot_product_attention_init)
+                                    TEDotProductAttentionPatch.__init__)
         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
             TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
 
     def patch_tensor_parallel(self):
-        from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
 
         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
@@ -186,6 +187,19 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     staticmethod,
                                     apply_wrapper=True)
+        # flux
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                    ColumnParallelLinearPatch.forward)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+                                    RowParallelLinearPatch.forward)
 
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
...
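Note on the registration calls above: MegatronAdaptation.register receives the dotted path of an upstream Megatron attribute plus a replacement, and apply_wrapper=True appears to apply the replacement as a wrapper around the original rather than overwriting it. The project's own implementation is not part of this diff; the following is only a minimal sketch of the monkey-patching pattern those calls rely on, with the apply_wrapper semantics assumed.

import importlib


def register(dotted_path: str, replacement, apply_wrapper: bool = False):
    """Sketch only: swap the attribute addressed by dotted_path for `replacement`."""
    owner_path, attr_name = dotted_path.rsplit('.', 1)
    try:
        # Common case: the owner is an importable module.
        owner = importlib.import_module(owner_path)
    except ModuleNotFoundError:
        # Otherwise the owner is an object inside a module, e.g. a class.
        module_path, owner_name = owner_path.rsplit('.', 1)
        owner = getattr(importlib.import_module(module_path), owner_name)
    original = getattr(owner, attr_name)
    # apply_wrapper (assumed): treat the replacement as a decorator of the original.
    setattr(owner, attr_name, replacement(original) if apply_wrapper else replacement)

Under that reading, the flux hooks above replace ColumnParallelLinear.forward and RowParallelLinear.forward with the patch classes' forward methods and decorate their __init__ with parallel_linear_init_wrapper.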
-from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
 from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
 import os
 import dataclasses
+import transformer_engine as te
 from typing import Any, Optional
 from packaging.version import Version as PkgVersion
@@ -19,7 +20,8 @@ from megatron.core.parallel_state import (
 )
 
-def te_dot_product_attention_init(
+class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
+    def __init__(
         self,
         config: TransformerConfig,
         layer_number: int,
@@ -30,7 +32,7 @@ def te_dot_product_attention_init(
         k_channels: Optional[int] = None,
         v_channels: Optional[int] = None,
         cp_comm_type: str = "p2p",
     ):
         self.config = config
         self.te_forward_mask_type = False
         self.qkv_format: str = 'sbhd'
...
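This file turns the former free function te_dot_product_attention_init into a class whose unbound __init__ is registered over TEDotProductAttention.__init__ (see the adaptation hunk above). A toy illustration of that grafting pattern follows, with hypothetical names (Base, Upstream, UpstreamPatch) standing in for the real classes; note that the grafted initializer should call the shared base explicitly, because a zero-argument super() would stay bound to the patch class and fail on instances of the patched one.

# Toy sketch of grafting a patch class's __init__ onto an upstream class.
class Base:
    def __init__(self):
        self.base_ready = True


class Upstream(Base):              # stands in for TEDotProductAttention
    def __init__(self):
        Base.__init__(self)
        self.flavour = "upstream"


class UpstreamPatch(Base):         # stands in for TEDotProductAttentionPatch
    def __init__(self):
        # Explicit base call: zero-argument super() would be bound to
        # UpstreamPatch and raise a TypeError for an Upstream instance.
        Base.__init__(self)
        self.flavour = "patched"


# What the register(..., UpstreamPatch.__init__) call effectively does:
Upstream.__init__ = UpstreamPatch.__init__

obj = Upstream()
assert obj.base_ready and obj.flavour == "patched"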
+from .layers import (
+    parallel_linear_init_wrapper,
+    ColumnParallelLinearPatch,
+    RowParallelLinearPatch,
+    vocab_parallel_embedding_forward,
+    vocab_parallel_embedding_init,
+)
\ No newline at end of file
This diff is collapsed.
@@ -165,3 +165,12 @@ def _add_mtp_args(parser):
     group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                        help='Main model share embedding and output weight with mtp layer.')
     return parser
+
+
+def _add_flux_args(parser):
+    group = parser.add_argument_group(title='flux')
+    group.add_argument('--use-flux', action='store_true', default=False,
+                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear.')
+    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
+                       help='Whether to transpose the weight when using the flux kernel.')
+    return parser
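The new switches are defined above, but how _add_flux_args is wired into the training argument builder is not shown in this diff, so the snippet below only sanity-checks the flags with a copy of the function on a bare argparse parser.

# Standalone check of the new flags; _add_flux_args is copied from the patch
# above because the module it lives in is not named in this diff.
import argparse


def _add_flux_args(parser):
    group = parser.add_argument_group(title='flux')
    group.add_argument('--use-flux', action='store_true', default=False,
                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear.')
    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                       help='Whether to transpose the weight when using the flux kernel.')
    return parser


args = _add_flux_args(argparse.ArgumentParser()).parse_args(
    ['--use-flux', '--flux-transpose-weight'])
assert args.use_flux and args.flux_transpose_weight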