Commit 1f7b14ab authored by sdwldchl

rewrite mtp

parent 89d29a02
...@@ -5,6 +5,8 @@ import types
import argparse
import torch
from .adaptor_arguments import get_adaptor_args
class MegatronAdaptation:
"""
...@@ -21,6 +23,15 @@ class MegatronAdaptation:
for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
    adaptation.execute()
MegatronAdaptation.apply()
# from .patch_utils import MegatronPatchesManager
# args = get_adaptor_args()
# for feature in FEATURES_LIST:
# if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
# feature.register_patches(MegatronPatchesManager, args)
# MindSpeedPatchesManager.apply_patches()
# MegatronAdaptation.post_execute()

@classmethod
...@@ -87,38 +98,20 @@ class CoreAdaptation(MegatronAdaptationABC):
self.patch_miscellaneous()

def patch_core_distributed(self):
# mtp share embedding
from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
                            _allreduce_word_embedding_grads)
def patch_core_models(self):
from ..core.models.common.embeddings.language_model_embedding import (
    language_model_embedding_forward,
    language_model_embedding_init_func
)
from ..core.models.gpt.gpt_model import (
    gpt_model_forward,
    gpt_model_init,
    shared_embedding_or_output_weight,
)
from ..core.models.common.language_module.language_module import (
    setup_embeddings_and_output_layer,
    tie_embeddings_and_output_weights_state_dict,
)
from ..core.models.gpt.gpt_model import GPTModel
from ..training.utils import get_batch_on_this_tp_rank
# Embedding
MegatronAdaptation.register(
    'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
    language_model_embedding_init_func)
MegatronAdaptation.register(
    'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
    language_model_embedding_forward)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
# LanguageModule
MegatronAdaptation.register(
    'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
    setup_embeddings_and_output_layer)
...@@ -126,17 +119,16 @@ class CoreAdaptation(MegatronAdaptationABC):
    'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
    tie_embeddings_and_output_weights_state_dict)
MegatronAdaptation.register(
    'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
    shared_embedding_or_output_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)

# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                            transformer_block_init_wrapper)
...@@ -174,13 +166,10 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                            vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
                            vocab_parallel_embedding_init_wrapper,
                            apply_wrapper=True)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                            torch.compile(mode='max-autotune-no-cudagraphs'),
                            apply_wrapper=True)

# VocabParallelCrossEntropy
...@@ -211,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                            get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import forward_step_wrapper
# pipeline_parallel.schedules.forward_step
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
forward_step_wrapper,
apply_wrapper=True)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
...@@ -255,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
                            parallel_mlp_init_wrapper,
                            apply_wrapper=True)

# ParallelAttention
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
                            parallel_attention_init_wrapper,
                            apply_wrapper=True)
......
...@@ -148,11 +148,29 @@ class MegatronPatchesManager:
patches_info = {}

@staticmethod
def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
def register_patch(
    orig_func_or_cls_name,
    new_func_or_cls=None,
    force_patch=False,
    create_dummy=False,
    apply_wrapper=False,
    remove_origin_wrappers=False
):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
    orig_func_or_cls_name,
    new_func_or_cls,
    create_dummy,
    apply_wrapper=apply_wrapper,
    remove_origin_wrappers=remove_origin_wrappers
)
else:
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
    new_func_or_cls,
    force_patch,
    apply_wrapper=apply_wrapper,
    remove_origin_wrappers=remove_origin_wrappers
)

@staticmethod
def apply_patches():
......
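As a usage note (not part of the diff): the two new keyword arguments mirror the `apply_wrapper=True` registrations made by the adaptor above. The semantics in the sketch below are inferred from how wrapper functions such as `forward_step_wrapper` are registered elsewhere in this commit, and the `dcu_megatron.patch_utils` import path is an assumption based on the commented-out hint in the adaptor module, so treat this as an illustration rather than the manager's documented API.

# Hedged sketch: register a wrapper (a callable that receives the original
# function and returns a patched one) instead of a drop-in replacement.
from dcu_megatron.patch_utils import MegatronPatchesManager            # assumed path
from dcu_megatron.core.pipeline_parallel.schedules import forward_step_wrapper

MegatronPatchesManager.register_patch(
    'megatron.core.pipeline_parallel.schedules.forward_step',
    forward_step_wrapper,          # wrapper defined later in this commit
    apply_wrapper=True,            # wrap the original instead of replacing it
)
MegatronPatchesManager.apply_patches()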
...@@ -28,7 +28,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in mtp process stage.
# So we need to allreduce grads of embedding in the embedding group in these cases.
if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
......
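For orientation (not part of the diff): the hunk above widens the condition so that the MTP stage also participates in the shared-embedding gradient all-reduce. A minimal sketch of that reduction step, assuming the standard `megatron.core.parallel_state` helpers, looks like this:

import torch
from megatron.core import parallel_state

def _sync_shared_embedding_grads(weight: torch.nn.Parameter) -> None:
    # Hypothetical helper mirroring the allreduce performed by
    # _allreduce_word_embedding_grads: every rank holding a copy of the tied
    # embedding (first stage, last stage, MTP stage) contributes its gradient
    # so all copies receive the same update.
    if parallel_state.is_rank_in_embedding_group(ignore_virtual=True):
        grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
        grad = getattr(weight, grad_attr)
        torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())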
...@@ -4,6 +4,7 @@ import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
...@@ -27,7 +28,7 @@ def setup_embeddings_and_output_layer(self) -> None:
# So we need to copy embedding weights from pre processing stage as initial parameters
# in these cases.
if not self.share_embeddings_and_output_weights and not getattr(
    self.config, 'num_nextn_predict_layers', 0
    self.config, 'mtp_num_layers', 0
):
    return
...@@ -41,10 +42,10 @@ def setup_embeddings_and_output_layer(self) -> None:
if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
    self.shared_embedding_or_output_weight().shared_embedding = True

if self.post_process and not self.pre_process:
if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
    assert not parallel_state.is_pipeline_first_stage()
    # set word_embeddings weights to 0 here, then copy first
    # stage's weights using all_reduce below.
    # set weights of the duplicated embedding to 0 here,
    # then copy weights from pre processing stage using all_reduce below.
    weight = self.shared_embedding_or_output_weight()
    weight.data.fill_(0)
    weight.shared = True
...@@ -114,7 +115,7 @@ def tie_embeddings_and_output_weights_state_dict(
# layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
# the shared weights will be stored in embedding layer, and output layer will not have
# any weight.
if self.post_process and getattr(self, 'num_nextn_predict_layers', False):
if getattr(self, 'mtp_process', False):
    # No output layer
    assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
    return
......
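A companion sketch for the weight side of the tying above (same assumptions as the previous sketch, names hypothetical): stages that only hold a duplicate of the word embedding start from zeros, and an all-reduce over the embedding group copies the pre-processing stage's values into them.

import torch
from megatron.core import parallel_state

def _tie_duplicated_embedding(weight: torch.Tensor, holds_master_copy: bool) -> None:
    # Hypothetical illustration of the sync done in setup_embeddings_and_output_layer:
    # duplicates are zero-filled, the master copy keeps its values, and the
    # all-reduce makes every copy identical to the first stage's weights.
    if not holds_master_copy:
        weight.data.fill_(0)
    if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group():
        torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())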
import warnings
from typing import Optional, Union
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
...@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
...@@ -36,6 +36,55 @@ try:
except ImportError:
    warnings.warn('Apex is not installed.')
from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_mtp_block_spec(
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
use_transformer_engine: bool,
) -> MultiTokenPredictionBlockSubmodules:
"""GPT Multi-Token Prediction (MTP) block spec."""
num_layers_to_build = get_mtp_num_layers_to_build(config)
if num_layers_to_build == 0:
return None
if isinstance(spec, TransformerBlockSubmodules):
# get the spec for the last layer of decoder block
transformer_layer_spec = spec.layer_specs[-1]
elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
transformer_layer_spec = spec
else:
raise ValueError(f"Invalid spec: {spec}")
mtp_layer_spec = get_mtp_layer_spec(
transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
)
mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
offset = get_mtp_layer_offset(config)
# split the mtp layer specs to only include the layers that are built in this pipeline stage.
mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
if len(mtp_layer_specs) > 0:
assert (
len(mtp_layer_specs) == config.mtp_num_layers
), +f"currently all of the mtp layers must stage in the same pipeline stage."
mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
else:
mtp_block_spec = None
return mtp_block_spec
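A hedged usage sketch (not part of the diff) of how the two spec builders in this file fit together; the helper name is hypothetical and it assumes the remaining `get_gpt_layer_with_flux_spec` parameters keep their defaults:

from megatron.core.transformer.transformer_config import TransformerConfig


def build_model_specs(config: TransformerConfig, use_te: bool = True):
    """Hypothetical helper: build the decoder layer spec and the matching MTP block spec."""
    layer_spec = get_gpt_layer_with_flux_spec(num_experts=config.num_moe_experts)
    # get_gpt_mtp_block_spec returns None when config.mtp_num_layers is falsy, so the
    # result can be passed straight to the patched GPTModel(mtp_block_spec=...) argument.
    mtp_block_spec = get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=use_te)
    return layer_spec, mtp_block_spec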
def get_gpt_layer_with_flux_spec( def get_gpt_layer_with_flux_spec(
num_experts: Optional[int] = None, num_experts: Optional[int] = None,
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import logging
from typing import Literal, Optional
from functools import wraps
from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.transformer.transformer_config import TransformerConfig

from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
tie_output_layer_state_dict,
tie_word_embeddings_state_dict,
)
def gpt_model_init( class GPTModel(LanguageModule):
self, """GPT Transformer language model.
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
) -> None:
super(GPTModel, self).__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
column_parallel_linear_impl = FluxColumnParallelLinear
else:
column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = column_parallel_linear_impl(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
# add mtp
if self.num_nextn_predict_layers:
assert hasattr(self.config, "mtp_spec")
self.mtp_spec = self.config.mtp_spec
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
if self.post_process and self.training:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
self.config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
)
for i in range(self.num_nextn_predict_layers)
]
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
maximum size of sequence. This is used for positional embedding
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope], optional):
Position embedding type.. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
rope_scaling (bool, optional): Toggle RoPE scaling.
rope_scaling_factor (float): RoPE scaling factor. Default 8.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
if has_config_logger_enabled(self.config):
log_config_to_disk(
    self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
mtp_block_spec: Optional[ModuleSpec] = None,
) -> None:
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
self.mtp_block_spec = mtp_block_spec
self.mtp_process = mtp_block_spec is not None
if self.pre_process or self.mtp_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
def shared_embedding_or_output_weight(self) -> Tensor: # Cache for RoPE tensors which do not change between iterations.
"""Gets the emedding weight or output logit weights when share embedding and output weights set to True. self.rotary_pos_emb_cache = {}
Returns: # Transformer.
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight self.decoder = TransformerBlock(
""" config=self.config,
if self.pre_process or (self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0)): spec=transformer_layer_spec,
return self.embedding.word_embeddings.weight pre_process=self.pre_process,
elif self.post_process: post_process=self.post_process,
return self.output_layer.weight
return None
def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
if self.num_nextn_predict_layers == 0:
return (
[input_ids],
[labels],
[position_ids],
[attention_mask],
) )
return ( if self.mtp_process:
tensor_slide(input_ids, self.num_nextn_predict_layers), self.mtp = MultiTokenPredictionBlock(config=self.config, spec=self.mtp_block_spec)
tensor_slide(labels, self.num_nextn_predict_layers),
generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers), # Output
# not compatible with ppo attn_mask if self.post_process or self.mtp_process:
tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
) if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
parallel_linear_impl = FluxColumnParallelLinear
else:
parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = parallel_linear_impl(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def generate_nextn_position_ids(tensor, slice_num): if self.pre_process or self.post_process:
slides = tensor_slide(tensor, slice_num) self.setup_embeddings_and_output_layer()
if slides[0] is None:
return slides
for idx in range(1, len(slides)): if has_config_logger_enabled(self.config):
slides[idx] = regenerate_position_ids(slides[idx], idx) log_config_to_disk(
return slides self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
def regenerate_position_ids(tensor, offset): See megatron.model.transformer.set_input_tensor()
if tensor is None:
return None
tensor = tensor.clone() Args:
for i in range(tensor.size(0)): input_tensor (Tensor): Sets the input tensor for the model.
row = tensor[i] """
zero_mask = (row == 0)  # case where two sentences are concatenated
# This is usually handled in schedules.py but some inference code still
if zero_mask.any(): # gives us non-lists or None
first_zero_idx = torch.argmax(zero_mask.int()).item() if not isinstance(input_tensor, list):
tensor[i, :first_zero_idx] = torch.arange(first_zero_idx) input_tensor = [input_tensor]
else:
tensor[i] = tensor[i] - offset
return tensor
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args: assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
runtime_gather_output (bool): Gather output at runtime. Default None means self.decoder.set_input_tensor(input_tensor[0])
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# generate inputs for main and mtps def forward(
input_ids, labels, position_ids, attention_mask = slice_inputs(
self, self,
input_ids, input_ids: Tensor,
labels, position_ids: Tensor,
position_ids, attention_mask: Tensor,
attention_mask decoder_input: Tensor = None,
) labels: Tensor = None,
inference_params: InferenceParams = None,
# Decoder embedding. packed_seq_params: PackedSeqParams = None,
if decoder_input is not None: extra_block_kwargs: dict = None,
pass runtime_gather_output: Optional[bool] = None,
elif self.pre_process: loss_mask: Optional[Tensor] = None,
decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0]) ) -> Tensor:
else: """Forward function of the GPT Model This function passes the input tensors
# intermediate stage of pipeline through the embedding layer, and then the decoeder and finally into the post
# decoder will get hidden_states from encoder.input_tensor processing layer (optional).
decoder_input = None
It either returns the Loss values if labels are given or the final hidden units
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None Args:
rotary_pos_cos = None runtime_gather_output (bool): Gather output at runtime. Default None means
rotary_pos_sin = None `parallel_output` arg in the constructor will be used.
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: """
if not self.training and self.config.flash_decode and inference_params: # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Flash decoding uses precomputed cos and sin for RoPE # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length, # Decoder embedding.
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), if decoder_input is not None:
) pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else: else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( # intermediate stage of pipeline
inference_params, self.decoder, decoder_input, self.config, packed_seq_params # decoder will get hidden_states from encoder.input_tensor
) decoder_input = None
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len, # Rotary positional embeddings (embedding is None for PP intermediate devices)
packed_seq=packed_seq_params is not None rotary_pos_emb = None
and packed_seq_params.qkv_format == 'thd', rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
) )
if ( else:
(self.config.enable_cuda_graph or self.config.flash_decode) sequence_len_offset = None
and rotary_pos_cos is not None
and inference_params # Run decoder.
): hidden_states = self.decoder(
sequence_len_offset = torch.tensor( hidden_states=decoder_input,
[inference_params.sequence_len_offset] * inference_params.current_batch_size, attention_mask=attention_mask,
dtype=torch.int32, inference_params=inference_params,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
) )
else:
sequence_len_offset = None # logits and loss
output_weight = None
# Run decoder. if self.share_embeddings_and_output_weights:
hidden_states = self.decoder( output_weight = self.shared_embedding_or_output_weight()
hidden_states=decoder_input,
attention_mask=attention_mask[0], if self.mtp_process:
inference_params=inference_params, hidden_states = self.mtp(
rotary_pos_emb=rotary_pos_emb, input_ids=input_ids,
rotary_pos_cos=rotary_pos_cos, position_ids=position_ids,
rotary_pos_sin=rotary_pos_sin, labels=labels,
packed_seq_params=packed_seq_params, loss_mask=loss_mask,
sequence_len_offset=sequence_len_offset, hidden_states=hidden_states,
**(extra_block_kwargs or {}), attention_mask=attention_mask,
) inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
if not self.post_process: rotary_pos_cos=rotary_pos_cos,
return hidden_states rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
# logits and loss sequence_len_offset=sequence_len_offset,
output_weight = None embedding=self.embedding,
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
loss = 0
# Multi token prediction module
if self.num_nextn_predict_layers and self.training:
mtp_hidden_states = hidden_states
for i in range(self.num_nextn_predict_layers):
mtp_hidden_states, mtp_loss = self.mtp_layers[i](
mtp_hidden_states, # [s,b,h]
input_ids[i + 1],
position_ids[i + 1] if position_ids[0] is not None else None,
attention_mask[i + 1] if attention_mask[0] is not None else None,
labels[i + 1] if labels[0] is not None else None,
inference_params,
packed_seq_params,
extra_block_kwargs,
embedding_layer=self.embedding,
output_layer=self.output_layer, output_layer=self.output_layer,
output_weight=output_weight, output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
) )
loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss if (
self.mtp_process is not None
if ( and getattr(self.decoder, "main_final_layernorm", None) is not None
self.num_nextn_predict_layers ):
and getattr(self.decoder, "main_final_layernorm", None) is not None # move block main model final norms here
): hidden_states = self.decoder.main_final_layernorm(hidden_states)
# move block main model final norms here
hidden_states = self.decoder.main_final_layernorm(hidden_states) if not self.post_process:
return hidden_states
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output logits, _ = self.output_layer(
) hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids[0],
'position_ids': position_ids[0],
'attention_mask': attention_mask[0],
'decoder_input': decoder_input,
'logits': logits,
}
) )
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels[0] is None: if has_config_logger_enabled(self.config):
# [s b h] => [b s h] payload = OrderedDict(
return logits.transpose(0, 1).contiguous() {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def shared_embedding_or_output_weight(self) -> Tensor:
"""Gets the embedding weight or output logit weights when share input embedding and
output weights set to True or when use Multi-Token Prediction (MTP) feature.
Returns:
Tensor: During pre processing or MTP process it returns the input embeddings weight.
Otherwise, during post processing it returns the final output layers weight.
"""
if self.pre_process or self.mtp_process:
# Multi-Token Prediction (MTP) need both embedding layer and output layer.
# So there will be both embedding layer and output layer in the mtp process stage.
# In this case, if share_embeddings_and_output_weights is True, the shared weights
# will be stored in embedding layer, and output layer will not have any weight.
assert hasattr(
self, 'embedding'
), f"embedding is needed in this pipeline stage, but it is not initialized."
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.output_layer.weight
return None
loss += self.compute_language_model_loss(labels[0], logits) def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Sharded state dict implementation for GPTModel backward-compatibility
(removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the
# _extra_state key but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
# Multi-Token Prediction (MTP) need both embedding layer and output layer in
# mtp process stage.
# If MTP is not placed in the pre processing stage, we need to maintain a copy of
# embedding layer in the mtp process stage and tie it to the embedding in the pre
# processing stage.
# Also, if MTP is not placed in the post processing stage, we need to maintain a copy
# of output layer in the mtp process stage and tie it to the output layer in the post
# processing stage.
if self.mtp_process and not self.pre_process:
emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
emb_weight = self.embedding.word_embeddings.weight
tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key)
if self.mtp_process and not self.post_process:
# We only need to tie the output layer weight if share_embeddings_and_output_weights
# is False. Because if share_embeddings_and_output_weights is True, the shared weight
# will be stored in embedding layer, and output layer will not have any weight.
if not self.share_embeddings_and_output_weights:
output_layer_weight_key = f'{prefix}output_layer.weight'
output_layer_weight = self.output_layer.weight
tie_output_layer_state_dict(
sharded_state_dict, output_layer_weight, output_layer_weight_key
)
return loss return sharded_state_dict
import torch
from functools import wraps
from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
def forward_step_wrapper(fn):
@wraps(fn)
def wrapper(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs,
):
output, num_tokens = fn(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs
)
if not isinstance(input_tensor, list):
# unwrap_output_tensor True
output_tensor = output
else:
output_tensor = output[0]
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output, num_tokens
return wrapper
\ No newline at end of file
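To show how the wrapper above is consumed (a sketch, not part of the diff): `patch_pipeline_parallel` registers it with `apply_wrapper=True`, which is assumed to be equivalent to wrapping Megatron's own `forward_step` in place, roughly:

import megatron.core.pipeline_parallel.schedules as schedules

# Assumed effect of MegatronAdaptation.register(..., forward_step_wrapper, apply_wrapper=True):
# the original forward_step is passed through the wrapper once, and the wrapped
# version (which refreshes the MTP loss scale on every microbatch) is patched back in.
schedules.forward_step = forward_step_wrapper(schedules.forward_step)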
from .layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init_wrapper,
)
\ No newline at end of file
import os
import copy
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional

if int(os.getenv("USE_FLUX_OVERLAP", "0")):
    try:
        import flux
        from dcu_megatron.core.utils import is_flux_min_version
    except ImportError:
        raise ImportError("flux is NOT installed")
try:
    import flux
except ImportError:
    raise ImportError("flux is NOT installed")

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
    is_torch_min_version,
    prepare_input_tensors_for_wgrad_compute
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
    VocabParallelEmbedding,
)
from megatron.core.tensor_parallel.mappings import (
    _reduce,
    copy_to_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    _reduce_scatter_along_first_dim,
    _gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
...@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
    custom_fwd,
    custom_bwd,
    dist_all_gather_func,
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce
)
from dcu_megatron.core.utils import is_flux_min_version

_grad_accum_fusion_available = True
try:
...@@ -61,74 +41,6 @@ except ImportError:
_grad_accum_fusion_available = False
def vocab_parallel_embedding_init_wrapper(fn):
@wraps(fn)
def wrapper(self,
*args,
skip_weight_param_allocation: bool = False,
**kwargs
):
if (
skip_weight_param_allocation
and "config" in kwargs
and hasattr(kwargs["config"], "perform_initialization")
):
config = copy.deepcopy(kwargs["config"])
config.perform_initialization = False
kwargs["config"] = config
fn(self, *args, **kwargs)
if skip_weight_param_allocation:
self.weight = None
return wrapper
@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
"""Forward.
Args:
input_ (torch.Tensor): Input tensor.
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
if self.reduce_scatter_embeddings:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
def get_tensor_model_parallel_node_size(group=None):
    """ Get the number of nodes
    """
......
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from torch import Tensor
from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import (
all_gather_last_dim_from_tensor_parallel_region,
scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor
SUPPORTED_ATTN_MASK = [
AttnMaskType.padding,
AttnMaskType.causal,
AttnMaskType.no_mask,
AttnMaskType.padding_causal,
]
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDelayedScaling,
TENorm,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
from megatron.core.transformer.torch_norm import WrappedTorchNorm
try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def tie_word_embeddings_state_dict(
sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str
) -> None:
"""tie the embedding of the mtp processing stage in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
word_emb_weight (Tensor): weight of the word embedding.
word_emb_weight_key (str): key of the word embedding in the sharded state dict.
Returns: None, acts in-place
"""
mtp_word_emb_replica_id = (
1, # copy of embedding in pre processing stage
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
assert word_emb_weight_key in sharded_state_dict
del sharded_state_dict[word_emb_weight_key]
sharded_state_dict[word_emb_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=word_emb_weight,
key=word_emb_weight_key,
replica_id=mtp_word_emb_replica_id,
allow_shape_mismatch=True,
)
def tie_output_layer_state_dict(
sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str
) -> None:
"""tie the output layer of the mtp processing stage in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
output_layer_weight (Tensor): weight of the output layer.
output_layer_weight_key (str): key of the output layer in the sharded state dict.
Returns: None, acts in-place
"""
mtp_output_layer_replica_id = (
1, # copy of output layer in post processing stage
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
assert output_layer_weight_key in sharded_state_dict
del sharded_state_dict[output_layer_weight_key]
sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=output_layer_weight,
key=output_layer_weight_key,
replica_id=mtp_output_layer_replica_id,
allow_shape_mismatch=True,
)
def roll_tensor(tensor, shifts=-1, dims=-1):
"""Roll the tensor input along the given dimension(s).
Inserted elements are set to be 0.0.
"""
rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
rolled_tensor.select(dims, shifts).fill_(0)
return rolled_tensor, rolled_tensor.sum()
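A quick illustration of the helper above (not part of the diff): rolling the labels left by one is how each successive MTP layer obtains targets shifted one more token into the future, with the vacated last position zeroed; applying the same roll to the loss mask makes the returned sum count the remaining valid positions.

import torch

labels = torch.tensor([[10, 11, 12, 13]])
loss_mask = torch.ones_like(labels)

rolled_labels, _ = roll_tensor(labels, shifts=-1, dims=-1)       # tensor([[11, 12, 13, 0]])
rolled_mask, num_valid = roll_tensor(loss_mask, shifts=-1, dims=-1)
# rolled_mask -> tensor([[1, 1, 1, 0]]); num_valid -> tensor(3)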
class MTPLossLoggingHelper:
"""Helper class for logging MTP losses."""
tracker = {}
@staticmethod
def save_loss_to_tracker(
loss: torch.Tensor,
layer_number: int,
num_layers: int,
reduce_group: torch.distributed.ProcessGroup = None,
avg_group: torch.distributed.ProcessGroup = None,
):
"""Save the mtp loss for logging.
Args:
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
"""
# Skip mtp loss logging if layer_number is None.
if layer_number is None:
return
tracker = MTPLossLoggingHelper.tracker
if "values" not in tracker:
tracker["values"] = torch.zeros(num_layers, device=loss.device)
tracker["values"][layer_number] += loss.detach()
tracker["reduce_group"] = reduce_group
tracker["avg_group"] = avg_group
def clean_loss_in_tracker():
"""Clear the mtp losses."""
tracker = MTPLossLoggingHelper.tracker
tracker["values"].zero_()
tracker["reduce_group"] = None
tracker["avg_group"] = None
def reduce_loss_in_tracker():
"""Collect and reduce the mtp losses across ranks."""
tracker = MTPLossLoggingHelper.tracker
if "values" not in tracker:
return
values = tracker["values"]
# Reduce mtp losses across ranks.
if tracker.get('reduce_group') is not None:
torch.distributed.all_reduce(values, group=tracker.get('reduce_group'))
if tracker.get('avg_group') is not None:
torch.distributed.all_reduce(
values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.AVG
)
def track_mtp_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None):
"""Track the Multi-Token Prediction (MTP) metrics for logging."""
MTPLossLoggingHelper.reduce_loss_in_tracker()
tracker = MTPLossLoggingHelper.tracker
if "values" not in tracker:
return
mtp_losses = tracker["values"] * loss_scale
mtp_num_layers = mtp_losses.shape[0]
for i in range(mtp_num_layers):
name = f"mtp_{i+1} loss"
loss = mtp_losses[i]
if total_loss_dict is not None:
total_loss_dict[name] = loss
if writer is not None:
writer.add_scalar(name, loss, iteration)
if wandb_writer is not None:
wandb_writer.log({f"{name}": loss}, iteration)
MTPLossLoggingHelper.clean_loss_in_tracker()
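# Intended logging flow (sketch; the writer objects passed in by the trainer are
# assumptions, only the helper itself is defined here):
#   MTPLossLoggingHelper.save_loss_to_tracker(loss, layer_number, num_layers)   # per MTP layer, in forward
#   MTPLossLoggingHelper.track_mtp_metrics(                                     # once per logging interval
#       loss_scale, iteration, tb_writer, wandb_writer, total_loss_dict
#   )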
@dataclass
class MultiTokenPredictionLayerSubmodules:
"""
Dataclass for specifying the submodules of a MultiTokenPrediction module.
Args:
hnorm (Union[ModuleSpec, type]): Specification or instance of the
hidden states normalization to be applied.
enorm (Union[ModuleSpec, type]): Specification or instance of the
embedding normalization to be applied.
eh_proj (Union[ModuleSpec, type]): Specification or instance of the
linear projection to be applied.
transformer_layer (Union[ModuleSpec, type]): Specification
or instance of the transformer block to be applied.
layer_norm (Union[ModuleSpec, type]): Specification or instance of the
final layer norm applied to the transformer output before the shared output head.
"""
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
layer_norm: Union[ModuleSpec, type] = None
def get_mtp_layer_spec(
transformer_layer_spec: ModuleSpec, use_transformer_engine: bool
) -> ModuleSpec:
"""Get the MTP layer spec.
Returns:
ModuleSpec: Module specification for a single MTP layer, built with Transformer
Engine modules when use_transformer_engine is True and local implementations otherwise.
"""
if use_transformer_engine:
assert HAVE_TE, "transformer_engine should be installed if use_transformer_engine is True"
layer_norm_impl = TENorm
column_parallel_linear_impl = TEColumnParallelLinear
else:
layer_norm_impl = LNImpl
column_parallel_linear_impl = ColumnParallelLinear
mtp_layer_spec = ModuleSpec(
module=MultiTokenPredictionLayer,
submodules=MultiTokenPredictionLayerSubmodules(
enorm=layer_norm_impl,
hnorm=layer_norm_impl,
eh_proj=column_parallel_linear_impl,
transformer_layer=transformer_layer_spec,
layer_norm=layer_norm_impl,
),
)
return mtp_layer_spec
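# Usage sketch (assumed wiring; a spec helper such as get_gpt_mtp_block_spec is
# expected to assemble something along these lines from a GPT decoder layer spec):
#   mtp_layer_spec = get_mtp_layer_spec(decoder_layer_spec, use_transformer_engine=True)
#   mtp_block_spec = ModuleSpec(
#       module=MultiTokenPredictionBlock,
#       submodules=MultiTokenPredictionBlockSubmodules(
#           layer_specs=[mtp_layer_spec] * config.mtp_num_layers
#       ),
#   )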
def get_mtp_layer_offset(config: TransformerConfig) -> int:
"""Get the offset of the MTP layer."""
# Currently, we only support putting all of the MTP layers on the last pipeline stage.
return 0
def get_mtp_num_layers_to_build(config: TransformerConfig) -> int:
"""Get the number of MTP layers to build."""
# Currently, we only support putting all of the MTP layers on the last pipeline stage.
if mpu.is_pipeline_last_stage():
return config.mtp_num_layers if config.mtp_num_layers else 0
else:
return 0
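# Example (illustrative): with config.mtp_num_layers = 1 and pipeline parallelism
# enabled, only ranks on the last pipeline stage build an MTP layer:
#   get_mtp_num_layers_to_build(config)   # -> 1 on the last stage, 0 elsewhere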
class MTPLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
@staticmethod
def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
"""Preserve the mtp by storing it in the context to avoid garbage collection.
Args:
output (torch.Tensor): The output tensor.
mtp_loss (torch.Tensor): The mtp loss tensor.
Returns:
torch.Tensor: The output tensor.
"""
ctx.save_for_backward(mtp_loss)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
"""Compute and scale the gradient for mtp loss..
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
gradient.
"""
(mtp_loss,) = ctx.saved_tensors
mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
return grad_output, scaled_mtp_loss_grad
@staticmethod
def set_loss_scale(scale: torch.Tensor):
"""set the scale of the mtp loss.
Args:
scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
matches the scale of the main_loss.
"""
MTPLossAutoScaler.main_loss_backward_scale = scale
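# Usage sketch: the MTP block below attaches the scaled MTP loss to the main-model
# hidden states so that backward() on the main loss also back-propagates the MTP
# loss; the training loss function is then expected to synchronize the scale
# (the caller-side variable names are assumptions):
#   hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
#   ...
#   MTPLossAutoScaler.set_loss_scale(scale_used_for_main_loss)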
class MultiTokenPredictionLayer(MegatronModule):
"""The implementation for Multi-Token Prediction (MTP) which extends
the prediction scope to multiple future tokens at each position.
This MTP implementation sequentially predicts additional tokens and keeps the complete
causal chain at each prediction depth, using D sequential modules to predict
D additional tokens.
The k-th MTP module consists of a shared embedding layer, a projection matrix,
a Transformer block, and a shared output head.
For the i-th input token at the (k - 1)-th prediction depth, we first combine
the representation of the i-th token and the embedding of the (i + k)-th token with
the linear projection. The combined representation serves as the input of the
Transformer block at the k-th depth to produce the output representation.
For more information, please refer to the DeepSeek-V3 Technical Report:
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
"""
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredictionLayerSubmodules,
layer_number: int = 1,
):
super().__init__(config=config)
self.sequence_parallel = config.sequence_parallel
self.submodules = submodules
self.layer_number = layer_number
self_attention_spec = self.submodules.transformer_layer.submodules.self_attention
attn_mask_type = self_attention_spec.params.get('attn_mask_type', '')
assert attn_mask_type in SUPPORTED_ATTN_MASK, (
f"Multi-Token Prediction (MTP) is not jet supported with "
+ f"{attn_mask_type} attention mask type."
+ f"The supported attention mask types are {SUPPORTED_ATTN_MASK}."
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
# of the i-th token's hidden states and the (i + k)-th token's decoder input,
# so the input's shape is [s, b, 2*h].
# The output will be sent to the following transformer layer,
# so the output's shape should be [s, b, h].
self.eh_proj = build_module(
self.submodules.eh_proj,
self.config.hidden_size * 2,
self.config.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
)
self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config)
self.final_layernorm = build_module(
self.submodules.layer_norm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
def forward(
self,
decoder_input: Tensor,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
):
"""
Perform the forward pass through the MTP layer.
Args:
hidden_states (Tensor): hidden states tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
At the (k - 1)-th MTP module, the i-th element of decoder input is
the embedding of the (i + k)-th token.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask for cross-attention context
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
Used as an alternative to apply attention mask for TE cuDNN attention.
inference_params (InferenceParams, optional): Parameters for inference-time
optimizations.
packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
processing.
Returns:
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
assert context is None, f"multi token prediction + cross attention is not yet supported."
assert (
packed_seq_params is None
), f"multi token prediction + sequence packing is not yet supported."
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(
with_context_parallel=True, tp_only_amax_red=self.config.tp_only_amax_red
)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context, fp8_context:
decoder_input = self.enorm(decoder_input)
decoder_input = make_viewless_tensor(
inp=decoder_input, requires_grad=True, keep_graph=True
)
hidden_states = self.hnorm(hidden_states)
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True
)
# At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
# and the (i + k)-th token's embedding, and combine them with the linear projection.
hidden_states = torch.cat((decoder_input, hidden_states), -1)
hidden_states, _ = self.eh_proj(hidden_states)
# For tensor parallel, all gather after linear_fc.
hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states)
# For sequence parallel, scatter after linear_fc and before transformer layer.
if self.sequence_parallel:
hidden_states = scatter_to_sequence_parallel_region(hidden_states)
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
# Layer norm before shared head layer.
hidden_states = self.final_layernorm(hidden_states)
# TENorm produces a "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
return hidden_states
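# Shape walk-through of the forward pass above (illustrative, ignoring tensor- and
# sequence-parallel partitioning), with sequence length s, batch b, hidden size h:
#   decoder_input    [s, b, h]   (embedding of token i + k)
#   hidden_states    [s, b, h]   (representation of token i)
#   torch.cat        [s, b, 2h]
#   eh_proj          [s, b, h]
#   transformer      [s, b, h]
#   final_layernorm  [s, b, h]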
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the multi token prediction layer.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the multi
token prediction layer.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
return sharded_state_dict
@dataclass
class MultiTokenPredictionBlockSubmodules:
"""
Dataclass for specifying the submodules of a multi token prediction block.
This class defines the structure for configuring the layers, allowing for
flexible and customizable architecture designs.
Args:
layer_specs (List[ModuleSpec], optional): A list of module specifications for
the layers within the multi token prediction block. Each specification typically
defines a complete multi token prediction layer (e.g., shared embedding,
projection matrix, transformer block, shared output head).
"""
layer_specs: List[ModuleSpec] = None
def _get_mtp_block_submodules(
config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]
) -> MultiTokenPredictionBlockSubmodules:
"""
Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.
Args:
config (TransformerConfig): Configuration object for the transformer model.
spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the
multi token prediction block submodules.
Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec.
Returns:
MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block.
"""
# Transformer block submodules.
if isinstance(spec, MultiTokenPredictionBlockSubmodules):
return spec
elif isinstance(spec, ModuleSpec):
if issubclass(spec.module, MultiTokenPredictionBlock):
return spec.submodules
else:
raise Exception(f"specialize for {spec.module.__name__}.")
else:
raise Exception(f"specialize for {type(spec).__name__}.")
class MultiTokenPredictionBlock(MegatronModule):
"""The implementation for Multi-Token Prediction (MTP) which extends
the prediction scope to multiple future tokens at each position.
This MTP implementation sequentially predicts additional tokens and keeps the complete
causal chain at each prediction depth, using D sequential modules to predict
D additional tokens.
The k-th MTP module consists of a shared embedding layer, a projection matrix,
a Transformer block, and a shared output head.
For the i-th input token at the (k - 1)-th prediction depth, we first combine
the representation of the i-th token and the embedding of the (i + k)-th token with
the linear projection. The combined representation serves as the input of the
Transformer block at the k-th depth to produce the output representation.
For more information, please refer to the DeepSeek-V3 Technical Report:
https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
"""
def __init__(
self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]
):
super().__init__(config=config)
self.submodules = _get_mtp_block_submodules(config, spec)
self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
self._build_layers()
assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer."
def _build_layers(self):
def build_layer(layer_spec, layer_number):
return build_module(layer_spec, config=self.config, layer_number=layer_number)
self.layers = torch.nn.ModuleList(
[
build_layer(layer_spec, i + 1)
for i, layer_spec in enumerate(self.submodules.layer_specs)
]
)
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
hidden_states: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
embedding=None,
output_layer=None,
output_weight: Optional[torch.Tensor] = None,
compute_language_model_loss=None,
) -> Tensor:
"""
Perform the forward pass through all of the MTP modules.
Args:
hidden_states (Tensor): Hidden states for input token with the shape [s, b, h]
where s is the sequence length, b is the batch size, and h is the hidden size.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
Returns:
(Tensor): The hidden states of the main model, of shape [s, b, h], with the scaled
MTP losses attached to them via MTPLossAutoScaler so they join the backward pass.
"""
assert (
labels is not None
), f"labels should not be None for calculating multi token prediction loss."
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(labels)
hidden_states_main_model = hidden_states
for layer_number in range(len(self.layers)):
# Calc logits for the current Multi-Token Prediction (MTP) layer.
input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
# norm, linear projection and transformer
hidden_states = self.layers[layer_number](
decoder_input=decoder_input,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
# output
mtp_logits, _ = output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
# Calc loss for the current Multi-Token Prediction (MTP) layer.
labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
mtp_loss = compute_language_model_loss(labels, mtp_logits)
mtp_loss = loss_mask * mtp_loss
if self.training:
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_tensor_and_context_parallel_group(),
)
mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states_main_model = MTPLossAutoScaler.apply(
hidden_states_main_model, mtp_loss_scale * mtp_loss
)
else:
hidden_states_main_model = MTPLossAutoScaler.apply(
hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens
)
return hidden_states_main_model
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the multi token prediction module.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the multi
token prediction module.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
layer_prefix = f'{prefix}layers.'
for layer in self.layers:
offset = get_mtp_layer_offset(self.config)
sharded_prefix = f'{layer_prefix}{layer.layer_number - 1}.'
state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.'
sharded_pp_offset = []
layer_sharded_state_dict = layer.sharded_state_dict(
state_dict_prefix, sharded_pp_offset, metadata
)
replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
sharded_state_dict.update(layer_sharded_state_dict)
return sharded_state_dict
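# Integration sketch (how a GPT model is expected to drive this block; the
# attribute names on the caller side are assumptions, not defined in this file):
#   hidden_states = self.mtp(
#       input_ids=input_ids, position_ids=position_ids,
#       hidden_states=hidden_states, attention_mask=attention_mask,
#       labels=labels, loss_mask=loss_mask,
#       embedding=self.embedding, output_layer=self.output_layer,
#       output_weight=output_weight,
#       compute_language_model_loss=self.compute_language_model_loss,
#   )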
...@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn): ...@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
# mtp requires separate layernorms for main model and mtp modules, thus move finalnorm out of block # mtp requires separate layernorms for main model and mtp modules, thus move finalnorm out of block
config = args[0] if len(args) > 1 else kwargs['config'] config = args[0] if len(args) > 1 else kwargs['config']
if getattr(config, "num_nextn_predict_layers", 0) > 0: if getattr(config, "mtp_num_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None self.final_layernorm = None
......
from typing import Optional
from functools import wraps
from dataclasses import dataclass from dataclasses import dataclass
from megatron.training import get_args
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
def transformer_config_post_init_wrapper(fn):
@wraps(fn)
def wrapper(self):
fn(self)
args = get_args()
"""Number of Multi-Token Prediction (MTP) Layers."""
self.mtp_num_layers = args.mtp_num_layers
"""Weighting factor of Multi-Token Prediction (MTP) loss."""
self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
##################
# flux
##################
self.flux_transpose_weight = args.flux_transpose_weight
return wrapper
@dataclass @dataclass
class ExtraTransformerConfig: class ExtraTransformerConfig:
################## ##################
# multi-token prediction # multi-token prediction
################## ##################
num_nextn_predict_layers: int = 0 mtp_num_layers: Optional[int] = None
"""The number of multi-token prediction layers""" """Number of Multi-Token Prediction (MTP) Layers."""
mtp_loss_scale: float = 0.3
"""Multi-token prediction loss scale"""
recompute_mtp_norm: bool = False
"""Whether to recompute mtp normalization"""
recompute_mtp_layer: bool = False
"""Whether to recompute mtp layer"""
share_mtp_embedding_and_output_weight: bool = False mtp_loss_scaling_factor: Optional[float] = None
"""share embedding and output weight with mtp layer.""" """Weighting factor of Multi-Token Prediction (MTP) loss."""
################## ##################
# flux # flux
......
...@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser): ...@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
def _add_mtp_args(parser): def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction') group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num') group.add_argument('--mtp-num-layers', type=int, default=None,
group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale') help='Number of Multi-Token Prediction (MTP) Layers.'
group.add_argument('--recompute-mtp-norm', action='store_true', default=False, 'MTP extends the prediction scope to multiple future tokens at each position.'
help='Multi-Token prediction recompute norm') 'This MTP implementation sequentially predicts additional tokens '
group.add_argument('--recompute-mtp-layer', action='store_true', default=False, 'by using D sequential modules to predict D additional tokens.')
help='Multi-Token prediction recompute layer') group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False, help='Scaling factor of Multi-Token Prediction (MTP) loss. '
help='Main model share embedding and output weight with mtp layer.') 'We compute the average of the MTP losses across all depths, '
'and multiply it the scaling factor to obtain the overall MTP loss, '
'which serves as an additional training objective.')
return parser return parser
......
...@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator): ...@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):
args = get_args() args = get_args()
def _broadcast(item): def _broadcast(item):
if item is not None: if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if mpu.get_tensor_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
if data_iterator is not None: if data_iterator is not None:
data = next(data_iterator) data = next(data_iterator)
else: else:
data = None data = None
batch = { batch = {
'tokens': data["tokens"].cuda(non_blocking = True), 'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True), 'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True), 'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True), 'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True) 'position_ids': data["position_ids"].cuda(non_blocking = True)
} }
if args.pipeline_model_parallel_size == 1: if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens']) _broadcast(batch['tokens'])
_broadcast(batch['labels']) _broadcast(batch['labels'])
_broadcast(batch['loss_mask']) _broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask']) _broadcast(batch['attention_mask'])
_broadcast(batch['position_ids']) _broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage(): elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens']) _broadcast(batch['tokens'])
_broadcast(batch['attention_mask']) _broadcast(batch['attention_mask'])
_broadcast(batch['position_ids']) _broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage(): elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers: # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(batch['tokens']) _broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(batch['position_ids']) _broadcast(batch['position_ids'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
else: else:
tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64, tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
device = torch.cuda.current_device()) labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
dtype = torch.int64, if args.create_attention_mask_in_dataloader:
device = torch.cuda.current_device()) attention_mask=torch.empty(
loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), (args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
dtype = torch.float32,
device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
device = torch.cuda.current_device()
) )
else: else:
attention_mask=None attention_mask=None
position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
dtype = torch.int64,
device = torch.cuda.current_device()) if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
if args.pipeline_model_parallel_size == 1: _broadcast(labels)
_broadcast(tokens) _broadcast(loss_mask)
_broadcast(labels) _broadcast(attention_mask)
_broadcast(loss_mask) _broadcast(position_ids)
_broadcast(attention_mask)
_broadcast(position_ids) elif mpu.is_pipeline_first_stage():
labels=None
elif mpu.is_pipeline_first_stage(): loss_mask=None
labels=None
loss_mask=None _broadcast(tokens)
_broadcast(attention_mask)
_broadcast(tokens) _broadcast(position_ids)
_broadcast(attention_mask)
_broadcast(position_ids) elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
elif mpu.is_pipeline_last_stage(): # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
if args.num_nextn_predict_layers: # Currently the Multi-Token Prediction (MTP) layers are fixed on the last stage, so we need
if args.mtp_num_layers is not None:
_broadcast(tokens) _broadcast(tokens)
else:
tokens = None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(position_ids) _broadcast(position_ids)
else: else:
position_ids = None tokens=None
position_ids=None
batch = {
'tokens': tokens, _broadcast(labels)
'labels': labels, _broadcast(loss_mask)
'loss_mask': loss_mask, _broadcast(attention_mask)
'attention_mask': attention_mask,
'position_ids': position_ids batch = {
} 'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
return batch return batch
...@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import ( ...@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_with_transformer_engine_spec, get_gpt_layer_with_transformer_engine_spec,
) )
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron import megatron_adaptor from dcu_megatron import megatron_adaptor
...@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat ...@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.") raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
# Define the mtp layer spec # Define the mtp layer spec
if isinstance(transformer_layer_spec, TransformerBlockSubmodules): mtp_block_spec = None
mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1] if args.mtp_num_layers is not None:
else: from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
mtp_transformer_layer_spec = transformer_layer_spec mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
with build_model_context(**build_model_context_args): with build_model_context(**build_model_context_args):
config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
model = GPTModel( model = GPTModel(
config=config, config=config,
transformer_layer_spec=transformer_layer_spec, transformer_layer_spec=transformer_layer_spec,
...@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat ...@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
position_embedding_type=args.position_embedding_type, position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent, rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base, rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec,
) )
# model = torch.compile(model,mode='max-autotune-no-cudagraphs') # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model) print_rank_0(model)
...@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): ...@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
args = get_args() args = get_args()
losses = output_tensor.float() losses = output_tensor.float()
if getattr(args, "num_nextn_predict_layers", 0) > 0:
loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
loss_mask = loss_mask.view(-1).float() loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum() total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
...@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel): ...@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator').stop() timers('batch-generator').stop()
with stimer: with stimer:
output_tensor = model(tokens, position_ids, attention_mask, if args.use_legacy_models:
labels=labels) output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels, loss_mask=loss_mask)
return output_tensor, partial(loss_func, loss_mask) return output_tensor, partial(loss_func, loss_mask)
...@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args): ...@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
return GPTDatasetConfig( return GPTDatasetConfig(
random_seed=args.seed, random_seed=args.seed,
sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0), sequence_length=args.seq_length,
blend=blend, blend=blend,
blend_per_split=blend_per_split, blend_per_split=blend_per_split,
split=args.split, split=args.split,
......