Commit 770fa304 authored by dongcl's avatar dongcl

Modify MTP

parent 8096abd4
import argparse
from .feature_manager import FEATURES_LIST
_ARGS = None
def process_args(parser):
parser.conflict_handler = 'resolve'
for feature in FEATURES_LIST:
feature.register_args(parser)
return parser
def get_adaptor_args():
global _ARGS
if _ARGS is None:
parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
_ARGS, _ = process_args(parser).parse_known_args()
return _ARGS
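# Illustrative usage (sketch, not part of this commit): parse_known_args() above means
# unrelated Megatron flags pass through untouched, while every feature in FEATURES_LIST
# first registers its own flags.
if __name__ == "__main__":
    example_args = get_adaptor_args()
    # Feature flags registered by FEATURES_LIST are now available, e.g.:
    print(getattr(example_args, "schedules_method", None))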
from .pipeline_parallel.dualpipev_feature import DualpipeVFeature
FEATURES_LIST = [
# Pipeline Parallel features
DualpipeVFeature()
]
# modified from mindspeed
import argparse
class BaseFeature:
def __init__(self, feature_name: str, optimization_level: int = 2):
self.feature_name = feature_name.strip().replace('-', '_')
self.optimization_level = optimization_level
self.default_patches = self.optimization_level == 0
def register_args(self, parser):
pass
def pre_validate_args(self, args):
pass
def validate_args(self, args):
pass
def post_validate_args(self, args):
pass
def register_patches(self, patch_manager, args):
...
def incompatible_check(self, global_args, check_args):
if getattr(global_args, self.feature_name, None) and getattr(global_args, check_args, None):
raise AssertionError('{} and {} are incompatible.'.format(self.feature_name, check_args))
def dependency_check(self, global_args, check_args):
if getattr(global_args, self.feature_name, None) and not getattr(global_args, check_args, None):
raise AssertionError('{} requires {}.'.format(self.feature_name, check_args))
@staticmethod
def add_parser_argument_choices_value(parser, argument_name, new_choice):
for action in parser._actions:
exist_arg = isinstance(action, argparse.Action) and argument_name in action.option_strings
if exist_arg and action.choices is not None and new_choice not in action.choices:
action.choices.append(new_choice)
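# Minimal sketch of a custom feature built on BaseFeature (illustrative only; the
# feature name and flags below are hypothetical and not part of this repository).
class ExampleFeature(BaseFeature):
    def __init__(self):
        super().__init__('example-feature', optimization_level=2)

    def register_args(self, parser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--example-feature', action='store_true',
                           help='Enable the hypothetical example feature.')

    def validate_args(self, args):
        # BaseFeature helpers can be reused for cross-flag checks;
        # the flag name checked here is also hypothetical.
        self.incompatible_check(args, 'some_conflicting_flag')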
from argparse import ArgumentParser
from ..base_feature import BaseFeature
class MTPFeature(BaseFeature):
def __init__(self):
super().__init__('schedules-method')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--schedules-method', type=str,
default=None, choices=['dualpipev'])
def register_patches(self, patch_manager, args):
from ...core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
from ...core.models.common.language_module.language_module import (
setup_embeddings_and_output_layer,
tie_embeddings_and_output_weights_state_dict,
)
from ...core.models.gpt.gpt_model import GPTModel
from ...training.utils import get_batch_on_this_tp_rank
from ...core.pipeline_parallel.schedules import forward_step_wrapper
from ...core import transformer_block_init_wrapper
patch_manager.register_patch('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
# LanguageModule
patch_manager.register_patch(
'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
setup_embeddings_and_output_layer)
patch_manager.register_patch(
'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
tie_embeddings_and_output_weights_state_dict)
patch_manager.register_patch('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
# Transformer block
patch_manager.register_patch('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
# pipeline_parallel.schedules.forward_step
patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_step',
forward_step_wrapper,
apply_wrapper=True)
# Modified from mindspeed.
from argparse import ArgumentParser
from ..base_feature import BaseFeature
class DualpipeVFeature(BaseFeature):
def __init__(self):
super().__init__('schedules-method')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--schedules-method', type=str,
default=None, choices=['dualpipev'])
def validate_args(self, args):
if args.schedules_method == "dualpipev":
if args.num_layers_per_virtual_pipeline_stage is not None:
raise AssertionError(
"The dualpipev and virtual_pipeline are incompatible.")
if args.num_layers < args.pipeline_model_parallel_size * 2:
raise AssertionError(
'number of layers must be at least 2*pipeline_model_parallel_size in dualpipe')
num_micro_batch = args.global_batch_size // args.micro_batch_size // args.data_parallel_size
if num_micro_batch < args.pipeline_model_parallel_size * 2 - 1:
raise AssertionError(
"num_micro_batch must be at least pipeline_model_parallel_size * 2 - 1 for dualpipev")
def register_patches(self, patch_manager, args):
from megatron.training.utils import print_rank_0
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import forward_backward_pipelining_with_cutinhalf
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_chunks import (
get_model, dualpipev_fp16forward, get_num_layers_to_build, train_step, _allreduce_embedding_grads_wrapper)
if args.schedules_method == "dualpipev":
patch_manager.register_patch(
'megatron.training.training.get_model', get_model)
patch_manager.register_patch(
'megatron.training.training.train_step', train_step)
patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving',
forward_backward_pipelining_with_cutinhalf)
patch_manager.register_patch(
'megatron.legacy.model.module.Float16Module.forward', dualpipev_fp16forward)
patch_manager.register_patch(
'megatron.core.transformer.transformer_block.get_num_layers_to_build', get_num_layers_to_build)
patch_manager.register_patch(
'megatron.training.utils.print_rank_last', print_rank_0)
patch_manager.register_patch(
'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads', _allreduce_embedding_grads_wrapper)
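# Example launch flags (sketch): enabling this feature amounts to adding
# `--schedules-method dualpipev` to an otherwise ordinary Megatron pretraining command;
# validate_args above then enforces the layer-count and micro-batch constraints.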
......@@ -5,6 +5,8 @@ import types
import argparse
import torch
from .adaptor_arguments import get_adaptor_args
class MegatronAdaptation:
"""
......@@ -21,6 +23,15 @@ class MegatronAdaptation:
for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
adaptation.execute()
MegatronAdaptation.apply()
from .patch_utils import MegatronPatchesManager
args = get_adaptor_args()
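# Level-0 features always register their patches; higher-level features register
# only when their flag (feature_name) is set on the parsed adaptor args.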
for feature in FEATURES_LIST:
if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
feature.register_patches(MegatronPatchesManager, args)
MegatronPatchesManager.apply_patches()
# MegatronAdaptation.post_execute()
@classmethod
......@@ -87,47 +98,37 @@ class CoreAdaptation(MegatronAdaptationABC):
self.patch_miscellaneous()
def patch_core_distributed(self):
# Mtp share embedding
# mtp share embedding
from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
def patch_core_models(self):
from ..core.models.common.embeddings.language_model_embedding import (
language_model_embedding_forward,
language_model_embedding_init_func
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init_wrapper,
shared_embedding_or_mtp_embedding_weight
)
from ..core.models.common.language_module.language_module import (
setup_embeddings_and_output_layer,
tie_embeddings_and_output_weights_state_dict,
)
from ..core.models.gpt.gpt_model import GPTModel
from ..training.utils import get_batch_on_this_tp_rank
# Embedding
# LanguageModule
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
language_model_embedding_init_func)
'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
setup_embeddings_and_output_layer)
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
language_model_embedding_forward)
'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
tie_embeddings_and_output_weights_state_dict)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
from megatron.core.models.gpt.gpt_model import GPTModel
setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block
# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
......@@ -165,13 +166,11 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
vocab_parallel_embedding_init)
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
# VocabParallelCrossEntropy
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
......@@ -201,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import forward_step_wrapper
# pipeline_parallel.schedules.forward_step
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
forward_step_wrapper,
apply_wrapper=True)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
......@@ -245,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
parallel_mlp_init_wrapper,
apply_wrapper=True)
# ParallelAttention
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
parallel_attention_init_wrapper,
apply_wrapper=True)
......
......@@ -148,11 +148,29 @@ class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
def register_patch(
orig_func_or_cls_name,
new_func_or_cls=None,
force_patch=False,
create_dummy=False,
apply_wrapper=False,
remove_origin_wrappers=False
):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
orig_func_or_cls_name,
new_func_or_cls,
create_dummy,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
else:
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
new_func_or_cls,
force_patch,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
@staticmethod
def apply_patches():
......
......@@ -28,20 +28,13 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
if (
hasattr(model_module, "share_mtp_embedding_and_output_weight")
and model_module.share_mtp_embedding_and_output_weight
and config.num_nextn_predict_layers > 0
):
weight = model_module.shared_embedding_or_mtp_embedding_weight()
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in the post-processing stage. If using Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in the MTP processing stage.
# So we need to all-reduce the embedding grads over the embedding group in these cases.
if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
......
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
def language_model_embedding_init_func(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
skip_weight_param_allocation: bool = False
):
"""Patch language model embeddings init."""
super(LanguageModelEmbedding, self).__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
skip_weight_param_allocation=skip_weight_param_allocation
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def language_model_embedding_forward(self,
input_ids: Tensor,
position_ids: Tensor,
tokentype_ids: int = None,
weight: Tensor = None) -> Tensor:
"""Pacth forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
weight (Tensor): embedding weight
Returns:
Tensor: The output embeddings
"""
if weight is None:
if self.word_embeddings.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.word_embeddings.weight
word_embeddings = self.word_embeddings(input_ids, weight)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
import logging
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
def setup_embeddings_and_output_layer(self) -> None:
"""Sets up embedding layer in first stage and output layer in last stage.
This function initializes word embeddings in the final stage when we are
using pipeline parallelism and sharing word embeddings, and sets up param
attributes on the embedding and output layers.
"""
# Set `is_embedding_or_output_parameter` attribute.
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if self.post_process and self.output_layer.weight is not None:
self.output_layer.weight.is_embedding_or_output_parameter = True
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in the post-processing stage. If using Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in the MTP processing stage.
# So we need to copy embedding weights from the pre-processing stage as initial parameters
# in these cases.
if not self.share_embeddings_and_output_weights and not getattr(
self.config, 'mtp_num_layers', 0
):
return
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_output_weight().zero_out_wgrad = True
return
if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
self.shared_embedding_or_output_weight().shared_embedding = True
if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set weights of the duplicated embedding to 0 here,
# then copy weights from pre processing stage using all_reduce below.
weight = self.shared_embedding_or_output_weight()
weight.data.fill_(0)
weight.shared = True
weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def tie_embeddings_and_output_weights_state_dict(
self,
sharded_state_dict: ShardedStateDict,
output_layer_weight_key: str,
first_stage_word_emb_key: str,
) -> None:
"""Ties the embedding and output weights in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie
output_layer_weight_key (str): key of the output layer weight in the state dict.
This entry will be replaced with a tied version
first_stage_word_emb_key (str): this must be the same as the
ShardedTensor.key of the first stage word embeddings.
Returns: None, acts in-place
"""
if not self.post_process:
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
if self.pre_process:
# Output layer is equivalent to the embedding already
return
# If using Multi-Token Prediction (MTP), we need to maintain both the embedding layer and the
# output layer in the MTP processing stage. In this case, if share_embeddings_and_output_weights
# is True, the shared weights will be stored in the embedding layer, and the output layer will
# not have any weight.
if getattr(self, 'mtp_process', False):
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
# Replace the default output layer with a one sharing the weights with the embedding
del sharded_state_dict[output_layer_weight_key]
tensor = self.shared_embedding_or_output_weight()
last_stage_word_emb_replica_id = (
1, # copy of first stage embedding
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
......@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
......@@ -36,6 +36,55 @@ try:
except ImportError:
warnings.warn('Apex is not installed.')
from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_mtp_block_spec(
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
use_transformer_engine: bool,
) -> MultiTokenPredictionBlockSubmodules:
"""GPT Multi-Token Prediction (MTP) block spec."""
num_layers_to_build = get_mtp_num_layers_to_build(config)
if num_layers_to_build == 0:
return None
if isinstance(spec, TransformerBlockSubmodules):
# get the spec for the last layer of decoder block
transformer_layer_spec = spec.layer_specs[-1]
elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
transformer_layer_spec = spec
else:
raise ValueError(f"Invalid spec: {spec}")
mtp_layer_spec = get_mtp_layer_spec(
transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
)
mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
offset = get_mtp_layer_offset(config)
# split the mtp layer specs to only include the layers that are built in this pipeline stage.
mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
if len(mtp_layer_specs) > 0:
assert (
len(mtp_layer_specs) == config.mtp_num_layers
), +f"currently all of the mtp layers must stage in the same pipeline stage."
mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
else:
mtp_block_spec = None
return mtp_block_spec
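# Illustrative call (sketch): build the MTP block spec from an existing GPT layer spec.
# Here `config` is a TransformerConfig with mtp_num_layers set and `layer_spec` is the
# decoder layer spec (e.g. the one returned by get_gpt_layer_with_flux_spec below):
#
#     mtp_block_spec = get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=True)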
def get_gpt_layer_with_flux_spec(
num_experts: Optional[int] = None,
......
# Modified from mindspeed.
import torch
from functools import wraps
from typing import List, Optional
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.training.global_vars import get_args, get_timers
from megatron.training.utils import unwrap_model
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.legacy.model.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
from megatron.core.rerun_state_machine import get_rerun_state_machine
# Assumed import location: these helpers are defined in megatron.training.training
# in recent Megatron-LM releases; adjust if your Megatron version differs.
from megatron.training.training import (
    logical_and_across_model_parallel_group,
    reduce_max_stat_across_model_parallel_group,
)
from .dualpipev_schedules import get_dualpipe_chunk
def dualpipev_fp16forward(self, *inputs, **kwargs):
is_pipeline_first_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 0
if is_pipeline_first_stage:
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
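# In dualpipeV the model is cut in half: pipeline rank 0 hosts both the head of
# chunk 0 and the tail of chunk 1, so "last stage" here means first pipeline rank
# on the second chunk.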
is_pipeline_last_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
if is_pipeline_last_stage:
outputs = float16_to_fp32(outputs)
return outputs
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
pre_process, post_process = False, False
if mpu.is_pipeline_first_stage():
pre_process = True
args.dualpipev_first_chunk = True
first_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
first_model.model_type = model_type
model.append(first_model)
args.dualpipev_first_chunk = False
second_model = model_provider_func(
pre_process=post_process,
post_process=pre_process
)
second_model.model_type = model_type
model.append(second_model)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(
param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if wrap_with_ddp:
config = get_model_config(model[0])
ddp_config = DistributedDataParallelConfig(
grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32,
overlap_grad_reduce=args.overlap_grad_reduce,
use_distributed_optimizer=args.use_distributed_optimizer,
check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad,
bucket_size=args.ddp_bucket_size,
average_in_collective=args.ddp_average_in_collective)
model = [DDP(config,
ddp_config,
model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0))
for (model_chunk_idx, model_chunk) in enumerate(model)]
# Broadcast params from data parallel src rank to other data parallel ranks.
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
return model
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
rerun_state_machine = get_rerun_state_machine()
while rerun_state_machine.should_run_forward_backward(data_iterator):
# Set grad to zero.
for model_chunk in model:
model_chunk.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False)
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Vision gradients.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
dualpipev_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
if dualpipev_last_stage:
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0].keys():
numerator = 0
denominator = 0
for x in losses_reduced:
val = x[key]
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if isinstance(val, tuple) or isinstance(val, list):
numerator += val[0]
denominator += val[1]
else:
# legacy behavior. we average over the number of microbatches,
# and so the denominator is 1.
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
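# In dualpipeV each pipeline rank hosts two model chunks (the model is cut in half),
# so a rank builds half of the usual per-pipeline-rank layer count for each chunk.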
def get_num_layers_to_build(config: TransformerConfig) -> int:
num_layers_per_pipeline_rank = (
config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
num_layers_to_build = num_layers_per_pipeline_rank // 2
return num_layers_to_build
def _allreduce_embedding_grads_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if get_args().schedules_method == 'dualpipev':
# dualpipev does not need the embedding grad all-reduce:
# the embedding and lm head live on the same rank.
if not get_args().untie_embeddings_and_output_weights:
raise NotImplementedError
else:
return
else:
return fn(*args, **kwargs)
return wrapper
from .modules.layers import linear_backward_wgrad_detach, ColumnParallelLinear, RowParallelLinear
from .modules.experts import group_mlp_forward_detach
from .transformer_layer import transformer_layer_forward_backward_overlaping
from .gpt_model import gpt_model_forward_backward_overlaping
from .vpp_schedules import forward_backward_pipelining_with_interleaving
\ No newline at end of file
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
def _make_param_hook(
self,
param: torch.nn.Parameter,
param_to_buffer,
):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
"""
def param_hook(*unused):
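# 'skip_grad_accum' appears to be set externally (presumably by the dualpipeV
# split-backward weight-gradient store) to defer accumulating this parameter's
# grad into main_grad; the flag is cleared at the end of the hook once honored.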
if param.requires_grad and not getattr(param, 'skip_grad_accum', False):
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
# Maybe this should be called after weightgradstore.pop()
if self.ddp_config.overlap_grad_reduce:
param_to_buffer[param].register_grad_ready(param)
if getattr(param, 'skip_grad_accum', False):
param.skip_grad_accum = False
return param_hook
\ No newline at end of file
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import logging
from typing import Dict, Literal, Optional, Tuple, Union, List
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams
from .transformer_block import (
transformer_block_backward, transformer_block_forward_backward_overlaping,
transformer_block_forward
)
from .modules.utils import (
LayerGraph, detach_tensor, run_graph_backward
)
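# ModelGraph captures the detached per-layer graphs plus the (detached) block input
# produced by a forward pass, so the corresponding backward pass can later be replayed
# manually (see gpt_model_backward / run_graph_backward) when forward and backward of
# different microbatches are overlapped.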
class ModelGraph:
def __init__(
self,
layer_graphs: List[LayerGraph],
block_output,
preprocess_graph: Tensor = None,
preprocess_detached_output: Tensor = None,
):
self.preprocess_graph = (preprocess_graph, preprocess_detached_output)
self.layer_graphs = layer_graphs
self.block_output = block_output
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
preprocess_graph = None
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
preprocess_graph = decoder_input
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
preprocess_graph = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
detached_block_input = detach_tensor(decoder_input)
# Run decoder.
hidden_states, layer_graphs = transformer_block_forward(
self.decoder,
hidden_states=detached_block_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states, ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
logits = logits.transpose(0, 1).contiguous()
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return logits, graph
loss = self.compute_language_model_loss(labels, logits)
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return loss, graph
def gpt_model_backward(
model_grad,
model_graph: ModelGraph,
):
block_input_grad = transformer_block_backward(model_grad, model_graph.layer_graphs)
if model_graph.preprocess_graph[0] is not None:
run_graph_backward(model_graph.preprocess_graph, block_input_grad, keep_graph=True, keep_grad=True)
return None
else:
return block_input_grad
def gpt_model_forward_backward_overlaping(
fwd_model,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
):
if extra_block_kwargs is None or extra_block_kwargs['bwd_model_graph'] is None:
return gpt_model_forward(
fwd_model, input_ids, position_ids, attention_mask, decoder_input, labels, inference_params,
packed_seq_params, extra_block_kwargs
)
bwd_model_grad, bwd_model_graph = extra_block_kwargs['bwd_model_grad'], extra_block_kwargs['bwd_model_graph']
# Fwd Model Decoder embedding.
if decoder_input is not None:
preprocess_graph = None
elif fwd_model.pre_process:
decoder_input = fwd_model.embedding(input_ids=input_ids, position_ids=position_ids)
preprocess_graph = decoder_input
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
preprocess_graph = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if fwd_model.position_embedding_type == 'rope':
rotary_seq_len = fwd_model.rotary_pos_emb.get_rotary_seq_len(
inference_params, fwd_model.decoder, decoder_input, fwd_model.config
)
rotary_pos_emb = fwd_model.rotary_pos_emb(rotary_seq_len)
detached_block_input = detach_tensor(decoder_input)
# Run transformer block fwd & bwd overlapping
(hidden_states, layer_graphs), block_input_grad, pp_comm_output \
= transformer_block_forward_backward_overlaping(
fwd_model.decoder,
detached_block_input,
attention_mask,
bwd_model_grad,
bwd_model_graph.layer_graphs,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
pp_comm_params=extra_block_kwargs['pp_comm_params'],
bwd_pp_comm_params=extra_block_kwargs['bwd_pp_comm_params']
)
if bwd_model_graph.preprocess_graph[0] is not None:
run_graph_backward(bwd_model_graph.preprocess_graph, block_input_grad, keep_grad=True, keep_graph=True)
if not fwd_model.post_process:
return hidden_states, ModelGraph(layer_graphs, hidden_states, preprocess_graph,
detached_block_input), pp_comm_output
# logits and loss
output_weight = None
if fwd_model.share_embeddings_and_output_weights:
output_weight = fwd_model.shared_embedding_or_output_weight()
logits, _ = fwd_model.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
logits = logits.transpose(0, 1).contiguous()
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return logits, graph, pp_comm_output
loss = fwd_model.compute_language_model_loss(labels, logits)
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return loss, graph, pp_comm_output
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.training import get_args
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
AsyncAll2All_INPUT = []
AsyncAll2All_OUTPUT = []
def set_async_alltoall_inputs(*args):
AsyncAll2All_INPUT.append(args)
def get_async_alltoall_outputs():
return AsyncAll2All_OUTPUT.pop(0)
def launch_async_all2all():
global AsyncAll2All_INPUT
global AsyncAll2All_OUTPUT
if len(AsyncAll2All_INPUT) > 0:
input_, input_splits, output_splits, group = AsyncAll2All_INPUT.pop(0)
_, output, a2a_handle = async_all_to_all(
input_,
input_splits,
output_splits,
group
)
AsyncAll2All_OUTPUT.append((output, a2a_handle))
def launch_async_all2all_hook(_):
launch_async_all2all()
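# Deferred all-to-all mechanism: set_async_alltoall_inputs() queues the arguments,
# launch_async_all2all() pops them and starts the communication asynchronously, and
# get_async_alltoall_outputs() later returns the (output, handle) pair to wait on.
# Registering launch_async_all2all_hook as a tensor hook lets the dispatch be
# triggered from within autograd at a convenient point.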
def attention_forward(
self,
hidden_states,
residual,
attention_mask=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
recompute_norm=False
):
# Optional Input Layer norm
def pre_norm(hidden_states):
args = get_args()
input_layernorm_output = self.input_layernorm(hidden_states)
if getattr(args, 'input_layernorm_in_fp32', False):
input_layernorm_output = input_layernorm_output.float()
return input_layernorm_output
if recompute_norm:
self.norm_ckpt1 = CheckpointWithoutOutput()
input_layernorm_output = self.norm_ckpt1.checkpoint(pre_norm, False, hidden_states)
else:
input_layernorm_output = pre_norm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
if recompute_norm:
self.norm_ckpt1.discard_output()
hidden_states.register_hook(self.norm_ckpt1.recompute)
return hidden_states