Commit 9eb8683b authored by dongcl

Merge branch 'main' into megatron_v0.11.0

parents 6f016785 be9a69d7
@@ -123,7 +123,7 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper, transformer_block_forward
-        from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
+        from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
         # Transformer block
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
@@ -133,9 +133,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         # Transformer config
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
-                                    TransformerConfig)
+                                    TransformerConfigPatch)
         MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
-                                    MLATransformerConfig)
+                                    MLATransformerConfigPatch)
         # Moe
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
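The registrations above swap Megatron's TransformerConfig and MLATransformerConfig for the *Patch variants by dotted attribute path. A minimal sketch of how such a path-based patch registry could work (hypothetical `PatchRegistry`; the actual MegatronAdaptation internals are not part of this diff):

```python
import importlib

class PatchRegistry:
    """Hypothetical sketch of an attribute-path patch registry."""

    def __init__(self):
        self._patches = []  # (dotted_path, replacement, apply_wrapper)

    def register(self, dotted_path, replacement, apply_wrapper=False):
        self._patches.append((dotted_path, replacement, apply_wrapper))

    def apply(self):
        for dotted_path, replacement, apply_wrapper in self._patches:
            parts = dotted_path.split('.')
            # Find the longest importable module prefix, e.g.
            # 'megatron.core.transformer.transformer_config'.
            for i in range(len(parts) - 1, 0, -1):
                try:
                    owner = importlib.import_module('.'.join(parts[:i]))
                    attrs = parts[i:]
                    break
                except ModuleNotFoundError:
                    continue
            else:
                raise ImportError(f'cannot resolve {dotted_path!r}')
            # Walk down to the object that owns the final attribute.
            for name in attrs[:-1]:
                owner = getattr(owner, name)
            original = getattr(owner, attrs[-1], None)
            # Assumption: apply_wrapper=True means the registered callable is a
            # decorator that receives the original and returns the replacement.
            new_obj = replacement(original) if apply_wrapper else replacement
            setattr(owner, attrs[-1], new_obj)
```

Rebinding the module attribute only affects lookups that go through the module; code that already imported the original symbol by value keeps it, which is why adaptation layers of this kind are usually applied before the rest of the framework is imported.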
@@ -154,18 +154,19 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_core_extentions(self):
         import transformer_engine as te
-        from ..core.extensions.transformer_engine import te_dot_product_attention_init
+        from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
         from megatron.core.extensions.transformer_engine import TEGroupedLinear
         MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
-                                    te_dot_product_attention_init)
+                                    TEDotProductAttentionPatch.__init__)
         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
             TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)

     def patch_tensor_parallel(self):
-        from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
+        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
+        from ..core.tensor_parallel import ColumnParallelLinearPatch, RowParallelLinearPatch, parallel_linear_init_wrapper
         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
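The `GROUPED_GEMM_BatchLinear` branch above swaps `TEGroupedLinear`'s base class at runtime by assigning to `__bases__`. A self-contained toy of that technique (stand-in classes, not the real Transformer Engine ones):

```python
class Base:
    def gemm(self):
        return "reference gemm"

class FastBase:
    def gemm(self):
        return "batched gemm"

class Child(Base):
    pass

obj = Child()
print(obj.gemm())            # "reference gemm"

# Rebind the base class; method resolution now goes through FastBase.
# Existing instances are affected too, since lookup happens at call time.
# This only works when the old and new bases have compatible instance
# layouts (both plain Python classes here).
Child.__bases__ = (FastBase,)
print(obj.gemm())            # "batched gemm"
```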
@@ -186,6 +187,19 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     staticmethod,
                                     apply_wrapper=True)
+        # flux
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.forward",
+                                    ColumnParallelLinearPatch.forward)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.__init__",
+                                    parallel_linear_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
+                                    RowParallelLinearPatch.forward)

     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
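The flux registrations wrap `ColumnParallelLinear.__init__` and `RowParallelLinear.__init__` with `parallel_linear_init_wrapper` (via `apply_wrapper=True`) and replace their `forward` with the Patch classes' methods. The wrapper body is not shown in this diff; a minimal sketch of the shape it could take, assuming `apply_wrapper=True` passes the original callable into the registered wrapper:

```python
from functools import wraps

def parallel_linear_init_wrapper(orig_init):
    """Hypothetical sketch only; the real wrapper's body is not in this diff."""
    @wraps(orig_init)
    def wrapped_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)   # run Megatron's own constructor first
        # Placeholder for patch-specific setup (e.g. flux-related buffers/flags)
        # that the patched forward() can later rely on.
        self._flux_initialized = False
    return wrapped_init
```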
......
-from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
 from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
 import os
 import dataclasses
+import transformer_engine as te
 from typing import Any, Optional
 from packaging.version import Version as PkgVersion
@@ -19,7 +20,8 @@ from megatron.core.parallel_state import (
 )

-def te_dot_product_attention_init(
+class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
+    def __init__(
         self,
         config: TransformerConfig,
         layer_number: int,
@@ -30,7 +32,7 @@ def te_dot_product_attention_init(
         k_channels: Optional[int] = None,
         v_channels: Optional[int] = None,
         cp_comm_type: str = "p2p",
-):
+    ):
         self.config = config
         self.te_forward_mask_type = False
         self.qkv_format: str = 'sbhd'
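The free function `te_dot_product_attention_init` becomes `__init__` on a subclass of `te.pytorch.DotProductAttention`, and the adaptor now registers `TEDotProductAttentionPatch.__init__` over the upstream initializer. A toy of that rebinding pattern (stand-in classes, not the real TE or Megatron types):

```python
class Vendor:                     # stands in for te.pytorch.DotProductAttention
    def __init__(self, channels):
        self.channels = channels

class Upstream(Vendor):           # stands in for Megatron's TEDotProductAttention
    def __init__(self, channels):
        super().__init__(channels)
        self.fast_path = False

class UpstreamPatch(Vendor):      # stands in for TEDotProductAttentionPatch
    def __init__(self, channels):
        # Re-implement the constructor end to end and call the vendor base
        # directly, so nothing recurses once this is rebound onto Upstream.
        Vendor.__init__(self, channels)
        self.fast_path = True

# Equivalent of registering '...Upstream.__init__' -> UpstreamPatch.__init__:
Upstream.__init__ = UpstreamPatch.__init__

layer = Upstream(8)
assert layer.fast_path is True and layer.channels == 8
```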
......
from .layers import (
    parallel_linear_init_wrapper,
    ColumnParallelLinearPatch,
    RowParallelLinearPatch,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init,
)
\ No newline at end of file
@@ -165,3 +165,12 @@ def _add_mtp_args(parser):
     group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                        help='Main model share embedding and output weight with mtp layer.')
     return parser
+
+
+def _add_flux_args(parser):
+    group = parser.add_argument_group(title='flux')
+    group.add_argument('--use-flux', action='store_true', default=False,
+                       help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
+    group.add_argument('--flux-transpose-weight', action='store_true', default=False,
+                       help='Whether to transpose weight when using flux kernel')
+    return parser
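A standalone check of the two flag shapes added here (hypothetical minimal parser; in Megatron they are wired through the shared argument-building pipeline):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='flux')
group.add_argument('--use-flux', action='store_true', default=False,
                   help='If set, flux will be used in ColumnParallelLinear and RowParallelLinear')
group.add_argument('--flux-transpose-weight', action='store_true', default=False,
                   help='Whether to transpose weight when using flux kernel')

# Both flags default to False and flip on when passed.
args = parser.parse_args(['--use-flux'])
assert args.use_flux and not args.flux_transpose_weight
```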