Commit 770fa304 authored by dongcl

Modify MTP

parent 8096abd4
import argparse
from .feature_manager import FEATURES_LIST
_ARGS = None
def process_args(parser):
parser.conflict_handler = 'resolve'
for feature in FEATURES_LIST:
feature.register_args(parser)
return parser
def get_adaptor_args():
global _ARGS
if _ARGS is None:
parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
_ARGS, _ = process_args(parser).parse_known_args()
return _ARGS
from .pipeline_parallel.dualpipev_feature import DualpipeVFeature
FEATURES_LIST = [
# Pipeline Parallel features
DualpipeVFeature()
]
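Illustration only (not part of this commit): how the adaptor collects feature flags. Every feature in FEATURES_LIST attaches its arguments to one parser, and parse_known_args lets regular Megatron flags pass through untouched; the flag values below are assumptions for the example.

import argparse

# Mirror of process_args()/get_adaptor_args() above, using the FEATURES_LIST defined here.
parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
parser.conflict_handler = 'resolve'
for feature in FEATURES_LIST:
    feature.register_args(parser)

# '--micro-batch-size 1' stands in for a Megatron flag the adaptor does not know about;
# parse_known_args() leaves it in `unknown` instead of raising.
args, unknown = parser.parse_known_args(
    ['--schedules-method', 'dualpipev', '--micro-batch-size', '1']
)
assert args.schedules_method == 'dualpipev'
assert unknown == ['--micro-batch-size', '1']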
# modified from mindspeed
import argparse
class BaseFeature:
def __init__(self, feature_name: str, optimization_level: int = 2):
self.feature_name = feature_name.strip().replace('-', '_')
self.optimization_level = optimization_level
self.default_patches = self.optimization_level == 0
def register_args(self, parser):
pass
def pre_validate_args(self, args):
pass
def validate_args(self, args):
pass
def post_validate_args(self, args):
pass
def register_patches(self, patch_manager, args):
...
def incompatible_check(self, global_args, check_args):
if getattr(global_args, self.feature_name, None) and getattr(global_args, check_args, None):
raise AssertionError('{} and {} are incompatible.'.format(self.feature_name, check_args))
def dependency_check(self, global_args, check_args):
if getattr(global_args, self.feature_name, None) and not getattr(global_args, check_args, None):
raise AssertionError('{} requires {}.'.format(self.feature_name, check_args))
@staticmethod
def add_parser_argument_choices_value(parser, argument_name, new_choice):
for action in parser._actions:
exist_arg = isinstance(action, argparse.Action) and argument_name in action.option_strings
if exist_arg and action.choices is not None and new_choice not in action.choices:
action.choices.append(new_choice)
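A hypothetical subclass (illustration only, not part of this commit) showing how the BaseFeature hooks are meant to be used; the flag names '--example-overlap' and 'example_async_comm' are invented for the example.

class ExampleOverlapFeature(BaseFeature):
    """Hypothetical feature used only to illustrate the BaseFeature hooks."""

    def __init__(self):
        # Stored as 'example_overlap', so getattr(args, 'example_overlap') works.
        super().__init__('example-overlap')

    def register_args(self, parser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--example-overlap', action='store_true',
                           help='Hypothetical flag for illustration only.')

    def validate_args(self, args):
        # Both checks only fire when --example-overlap is set:
        self.dependency_check(args, 'example_async_comm')    # requires a (made-up) companion flag
        self.incompatible_check(args, 'example_sync_only')   # cannot combine with a (made-up) flag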
from argparse import ArgumentParser
from ..base_feature import BaseFeature
class MTPFeature(BaseFeature):
def __init__(self):
super().__init__('schedules-method')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--schedules-method', type=str,
default=None, choices=['dualpipev'])
def register_patches(self, patch_manager, args):
from ...core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
from ...core.models.common.language_module.language_module import (
setup_embeddings_and_output_layer,
tie_embeddings_and_output_weights_state_dict,
)
from ...core.models.gpt.gpt_model import GPTModel
from ...training.utils import get_batch_on_this_tp_rank
from ...core.pipeline_parallel.schedules import forward_step_wrapper
from ...core import transformer_block_init_wrapper
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
# LanguageModule
MegatronAdaptation.register(
'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
setup_embeddings_and_output_layer)
MegatronAdaptation.register(
'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
tie_embeddings_and_output_weights_state_dict)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
# Transformer block
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
# pipeline_parallel.schedules.forward_step
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
forward_step_wrapper,
apply_wrapper=True)
# Modified from mindspeed.
from argparse import ArgumentParser
from ..base_feature import BaseFeature
class DualpipeVFeature(BaseFeature):
def __init__(self):
super().__init__('schedules-method')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--schedules-method', type=str,
default=None, choices=['dualpipev'])
def validate_args(self, args):
if args.schedules_method == "dualpipev":
if args.num_layers_per_virtual_pipeline_stage is not None:
raise AssertionError(
"The dualpipev and virtual_pipeline are incompatible.")
if args.num_layers < args.pipeline_model_parallel_size * 2:
raise AssertionError(
'number of layers must be at least 2*pipeline_model_parallel_size in dualpipe')
num_micro_batch = args.global_batch_size // args.micro_batch_size // args.data_parallel_size
if num_micro_batch < args.pipeline_model_parallel_size * 2 - 1:
raise AssertionError(
"num_micro_batch should more than pipeline_model_parallel_size * 2 - 1")
def register_patches(self, patch_manager, args):
from megatron.training.utils import print_rank_0
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import forward_backward_pipelining_with_cutinhalf
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_chunks import (
get_model, dualpipev_fp16forward, get_num_layers_to_build, train_step, _allreduce_embedding_grads_wrapper)
if args.schedules_method == "dualpipev":
patch_manager.register_patch(
'megatron.training.training.get_model', get_model)
patch_manager.register_patch(
'megatron.training.training.train_step', train_step)
patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving',
forward_backward_pipelining_with_cutinhalf)
patch_manager.register_patch(
'megatron.legacy.model.module.Float16Module.forward', dualpipev_fp16forward)
patch_manager.register_patch(
'megatron.core.transformer.transformer_block.get_num_layers_to_build', get_num_layers_to_build)
patch_manager.register_patch(
'megatron.training.utils.print_rank_last', print_rank_0)
patch_manager.register_patch(
'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads', _allreduce_embedding_grads_wrapper)
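A worked example (illustrative numbers, not from the commit) of the dualpipev constraints enforced in validate_args above:

# num_layers must be at least 2 * pipeline_model_parallel_size, and the number of
# micro batches must be at least 2 * pipeline_model_parallel_size - 1.
global_batch_size, micro_batch_size, data_parallel_size = 64, 2, 4
pipeline_model_parallel_size, num_layers = 4, 16

num_micro_batch = global_batch_size // micro_batch_size // data_parallel_size  # 64 // 2 // 4 = 8
assert num_layers >= 2 * pipeline_model_parallel_size           # 16 >= 8
assert num_micro_batch >= 2 * pipeline_model_parallel_size - 1  # 8 >= 7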
@@ -5,6 +5,8 @@ import types
import argparse
import torch
from .adaptor_arguments import get_adaptor_args
class MegatronAdaptation:
"""
@@ -21,6 +23,15 @@ class MegatronAdaptation:
for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
adaptation.execute()
MegatronAdaptation.apply()
from .patch_utils import MegatronPatchesManager
args = get_adaptor_args()
for feature in FEATURES_LIST:
if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
feature.register_patches(MegatronPatchesManager, args)
MegatronPatchesManager.apply_patches()
# MegatronAdaptation.post_execute()
@classmethod
@@ -87,47 +98,37 @@ class CoreAdaptation(MegatronAdaptationABC):
self.patch_miscellaneous()
def patch_core_distributed(self):
# Mtp share embedding
# mtp share embedding
from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
def patch_core_models(self):
from ..core.models.common.embeddings.language_model_embedding import (
language_model_embedding_forward,
language_model_embedding_init_func
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init_wrapper,
shared_embedding_or_mtp_embedding_weight
from ..core.models.common.language_module.language_module import (
setup_embeddings_and_output_layer,
tie_embeddings_and_output_weights_state_dict,
)
from ..core.models.gpt.gpt_model import GPTModel
from ..training.utils import get_batch_on_this_tp_rank
# Embedding
# LanguageModule
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
language_model_embedding_init_func)
'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
setup_embeddings_and_output_layer)
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
language_model_embedding_forward)
'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
tie_embeddings_and_output_weights_state_dict)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gpt_model_init_wrapper,
apply_wrapper=True)
from megatron.core.models.gpt.gpt_model import GPTModel
setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block
# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
@@ -165,13 +166,11 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
vocab_parallel_embedding_init)
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
# VocabParallelCrossEntropy
MegatronAdaptation.register('megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',
@@ -201,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import forward_step_wrapper
# pipeline_parallel.schedules.forward_step
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
forward_step_wrapper,
apply_wrapper=True)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
@@ -245,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
parallel_mlp_init_wrapper,
apply_wrapper=True)
# ParallelAttention
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
parallel_attention_init_wrapper,
apply_wrapper=True)
......
@@ -148,11 +148,29 @@ class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
def register_patch(
orig_func_or_cls_name,
new_func_or_cls=None,
force_patch=False,
create_dummy=False,
apply_wrapper=False,
remove_origin_wrappers=False
):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
orig_func_or_cls_name,
new_func_or_cls,
create_dummy,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
else:
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
new_func_or_cls,
force_patch,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
@staticmethod
def apply_patches():
......
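Illustration only (not part of this commit) of how the extended register_patch signature might be used, assuming apply_wrapper=True means the registered callable receives the original function and returns the patched one, as the *_wrapper patches above suggest; the wrapper name below is hypothetical.

from functools import wraps

def example_logging_wrapper(orig_forward_step):
    # Hypothetical wrapper used purely for illustration.
    @wraps(orig_forward_step)
    def wrapped(*args, **kwargs):
        output = orig_forward_step(*args, **kwargs)
        return output
    return wrapped

MegatronPatchesManager.register_patch(
    'megatron.core.pipeline_parallel.schedules.forward_step',
    example_logging_wrapper,
    apply_wrapper=True,
)
MegatronPatchesManager.apply_patches()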
@@ -28,20 +28,13 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
if (
hasattr(model_module, "share_mtp_embedding_and_output_weight")
and model_module.share_mtp_embedding_and_output_weight
and config.num_nextn_predict_layers > 0
):
weight = model_module.shared_embedding_or_mtp_embedding_weight()
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in the post processing stage. If using Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in the mtp process stage.
# So we need to allreduce the embedding grads in the embedding group in these cases.
if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
......
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
def language_model_embedding_init_func(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
skip_weight_param_allocation: bool = False
):
"""Patch language model embeddings init."""
super(LanguageModelEmbedding, self).__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
skip_weight_param_allocation=skip_weight_param_allocation
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def language_model_embedding_forward(self,
input_ids: Tensor,
position_ids: Tensor,
tokentype_ids: int = None,
weight: Tensor = None) -> Tensor:
"""Pacth forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
weight (Tensor): embedding weight
Returns:
Tensor: The output embeddings
"""
if weight is None:
if self.word_embeddings.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.word_embeddings.weight
word_embeddings = self.word_embeddings(input_ids, weight)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
import logging
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
def setup_embeddings_and_output_layer(self) -> None:
"""Sets up embedding layer in first stage and output layer in last stage.
This function initializes word embeddings in the final stage when we are
using pipeline parallelism and sharing word embeddings, and sets up param
attributes on the embedding and output layers.
"""
# Set `is_embedding_or_output_parameter` attribute.
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if self.post_process and self.output_layer.weight is not None:
self.output_layer.weight.is_embedding_or_output_parameter = True
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in the post processing stage. If using Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in the mtp process stage.
# So we need to copy embedding weights from the pre processing stage as initial
# parameters in these cases.
if not self.share_embeddings_and_output_weights and not getattr(
self.config, 'mtp_num_layers', 0
):
return
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_output_weight().zero_out_wgrad = True
return
if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
self.shared_embedding_or_output_weight().shared_embedding = True
if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set weights of the duplicated embedding to 0 here,
# then copy weights from pre processing stage using all_reduce below.
weight = self.shared_embedding_or_output_weight()
weight.data.fill_(0)
weight.shared = True
weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, do an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def tie_embeddings_and_output_weights_state_dict(
self,
sharded_state_dict: ShardedStateDict,
output_layer_weight_key: str,
first_stage_word_emb_key: str,
) -> None:
"""Ties the embedding and output weights in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie
output_layer_weight_key (str): key of the output layer weight in the state dict.
This entry will be replaced with a tied version
first_stage_word_emb_key (str): this must be the same as the
ShardedTensor.key of the first stage word embeddings.
Returns: None, acts in-place
"""
if not self.post_process:
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
if self.pre_process:
# Output layer is equivalent to the embedding already
return
# If using Multi-Token Prediction (MTP), we need to maintain both the embedding layer and
# the output layer in the mtp process stage. In this case, if share_embeddings_and_output_weights
# is True, the shared weights will be stored in the embedding layer, and the output layer will
# not have any weight.
if getattr(self, 'mtp_process', False):
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
# Replace the default output layer with a one sharing the weights with the embedding
del sharded_state_dict[output_layer_weight_key]
tensor = self.shared_embedding_or_output_weight()
last_stage_word_emb_replica_id = (
1, # copy of first stage embedding
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
tensor=tensor,
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
)
@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
@@ -36,6 +36,55 @@ try:
except ImportError:
warnings.warn('Apex is not installed.')
from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_mtp_block_spec(
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
use_transformer_engine: bool,
) -> MultiTokenPredictionBlockSubmodules:
"""GPT Multi-Token Prediction (MTP) block spec."""
num_layers_to_build = get_mtp_num_layers_to_build(config)
if num_layers_to_build == 0:
return None
if isinstance(spec, TransformerBlockSubmodules):
# get the spec for the last layer of decoder block
transformer_layer_spec = spec.layer_specs[-1]
elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
transformer_layer_spec = spec
else:
raise ValueError(f"Invalid spec: {spec}")
mtp_layer_spec = get_mtp_layer_spec(
transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
)
mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
offset = get_mtp_layer_offset(config)
# split the mtp layer specs to only include the layers that are built in this pipeline stage.
mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
if len(mtp_layer_specs) > 0:
assert (
len(mtp_layer_specs) == config.mtp_num_layers
), +f"currently all of the mtp layers must stage in the same pipeline stage."
mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
else:
mtp_block_spec = None
return mtp_block_spec
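Illustration only (not from this diff) of how get_gpt_mtp_block_spec is typically wired into model construction; `config` and `transformer_layer_spec` are assumed to exist, and the numeric values are placeholders.

# Assuming `config` (a TransformerConfig with mtp_num_layers set) and
# `transformer_layer_spec` are already built elsewhere.
mtp_block_spec = get_gpt_mtp_block_spec(
    config, transformer_layer_spec, use_transformer_engine=True
)
model = GPTModel(
    config=config,
    transformer_layer_spec=transformer_layer_spec,
    vocab_size=32000,               # illustrative value
    max_sequence_length=4096,       # illustrative value
    mtp_block_spec=mtp_block_spec,  # None when mtp_num_layers is 0
)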
def get_gpt_layer_with_flux_spec(
num_experts: Optional[int] = None,
......
import os
import logging
from typing import Literal, Optional
from functools import wraps
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import OrderedDict
from typing import Dict, Literal, Optional
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
tie_output_layer_state_dict,
tie_word_embeddings_state_dict,
)
def gpt_model_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
class GPTModel(LanguageModule):
"""GPT Transformer language model.
if (
self.post_process
and int(os.getenv("USE_FLUX_OVERLAP", "0"))
):
self.output_layer = FluxColumnParallelLinear(
self.config.hidden_size,
self.vocab_size,
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
maximum size of sequence. This is used for positional embedding
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope], optional):
Position embedding type.. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
rope_scaling (bool, optional): Toggle RoPE scaling.
rope_scaling_factor (float): RoPE scaling factor. Default 8.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def __init__(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
rope_scaling_factor: float = 8.0,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
mtp_block_spec: Optional[ModuleSpec] = None,
) -> None:
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
self.mtp_block_spec = mtp_block_spec
self.mtp_process = mtp_block_spec is not None
if self.pre_process or self.mtp_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
init_method=self.config.init_method,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
rope_scaling_factor=rope_scaling_factor,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Cache for RoPE tensors which do not change between iterations.
self.rotary_pos_emb_cache = {}
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
)
if self.mtp_process:
self.mtp = MultiTokenPredictionBlock(config=self.config, spec=self.mtp_block_spec)
# Output
if self.post_process or self.mtp_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
parallel_linear_impl = FluxColumnParallelLinear
else:
parallel_linear_impl = tensor_parallel.ColumnParallelLinear
self.output_layer = parallel_linear_impl(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
@@ -48,324 +186,239 @@ def gpt_model_init_wrapper(fn):
grad_output_buffer=self.grad_output_buffer,
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
# add mtp
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
if self.num_nextn_predict_layers:
assert hasattr(self.config, "mtp_spec")
self.mtp_spec: ModuleSpec = self.config.mtp_spec
self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
if self.post_process and self.training:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
self.config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
)
for i in range(self.num_nextn_predict_layers)
]
)
if self.pre_process or self.post_process:
setup_mtp_embeddings(self)
return wrapper
def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
"""Gets the embedding weight when share embedding and mtp embedding weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight, while during post
processing it returns the mtp embedding layer's weight.
"""
assert self.num_nextn_predict_layers > 0
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.mtp_layers[0].embedding.word_embeddings.weight
return None
def setup_mtp_embeddings(self):
"""
Share embedding layer in mtp layer.
"""
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
# Set `is_embedding_or_output_parameter` attribute.
for i in range(self.num_nextn_predict_layers):
if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if not self.share_mtp_embedding_and_output_weight:
return
if self.pre_process and self.post_process:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
return
if self.pre_process and not self.post_process:
assert parallel_state.is_pipeline_first_stage()
self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
for i in range(self.num_nextn_predict_layers):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, do an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_mtp_embedding_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
if self.num_nextn_predict_layers == 0:
return (
[input_ids],
[labels],
[position_ids],
[attention_mask],
)
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
return (
tensor_slide(input_ids, self.num_nextn_predict_layers),
tensor_slide(labels, self.num_nextn_predict_layers),
generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers),
# not compatible with ppo attn_mask
tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
)
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
def generate_nextn_position_ids(tensor, slice_num):
slides = tensor_slide(tensor, slice_num)
if slides[0] is None:
return slides
for idx in range(1, len(slides)):
slides[idx] = regenerate_position_ids(slides[idx], idx)
return slides
def regenerate_position_ids(tensor, offset):
if tensor is None:
return None
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
tensor = tensor.clone()
for i in range(tensor.size(0)):
row = tensor[i]
zero_mask = (row == 0)  # case where two sequences are concatenated
if zero_mask.any():
first_zero_idx = torch.argmax(zero_mask.int()).item()
tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
else:
tensor[i] = tensor[i] - offset
return tensor
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# generate inputs for main and mtps
input_ids, labels, position_ids, attention_mask = slice_inputs(
def forward(
self,
input_ids,
labels,
position_ids,
attention_mask
)
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0])
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
else:
sequence_len_offset = None
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
else:
sequence_len_offset = None
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask[0],
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
loss = 0
# Multi token prediction module
if self.num_nextn_predict_layers and self.training:
if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
output_weight = self.output_layer.weight
output_weight.zero_out_wgrad = True
embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
mtp_hidden_states = hidden_states
for i in range(self.num_nextn_predict_layers):
mtp_hidden_states, mtp_loss = self.mtp_layers[i](
mtp_hidden_states, # [s,b,h]
input_ids[i + 1],
position_ids[i + 1] if position_ids[0] is not None else None,
attention_mask[i + 1] if attention_mask[0] is not None else None,
labels[i + 1] if labels[0] is not None else None,
inference_params,
packed_seq_params,
extra_block_kwargs,
embeding_weight=embedding_weight,
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
if self.mtp_process:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
labels=labels,
loss_mask=loss_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
compute_language_model_loss=self.compute_language_model_loss,
**(extra_block_kwargs or {}),
)
loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss
if (
self.num_nextn_predict_layers
and getattr(self.decoder, "main_final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = self.decoder.main_final_layernorm(hidden_states)
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids[0],
'position_ids': position_ids[0],
'attention_mask': attention_mask[0],
'decoder_input': decoder_input,
'logits': logits,
}
if (
self.num_nextn_predict_layers
and getattr(self.decoder, "main_final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = self.decoder.main_final_layernorm(hidden_states)
if not self.post_process:
return hidden_states
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels[0] is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return loss
def shared_embedding_or_output_weight(self) -> Tensor:
"""Gets the embedding weight or output logit weights when share input embedding and
output weights set to True or when use Multi-Token Prediction (MTP) feature.
Returns:
Tensor: During pre processing or MTP process it returns the input embeddings weight.
Otherwise, during post processing it returns the final output layers weight.
"""
if self.pre_process or self.mtp_process:
# Multi-Token Prediction (MTP) needs both the embedding layer and the output layer,
# so both are present in the mtp process stage.
# In this case, if share_embeddings_and_output_weights is True, the shared weights
# will be stored in the embedding layer, and the output layer will not have any weight.
assert hasattr(
self, 'embedding'
), f"embedding is needed in this pipeline stage, but it is not initialized."
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.output_layer.weight
return None
loss += self.compute_language_model_loss(labels[0], logits)
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
) -> ShardedStateDict:
"""Sharded state dict implementation for GPTModel backward-compatibility
(removing extra state).
Args:
prefix (str): Module name prefix.
sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
metadata (Optional[Dict]): metadata controlling sharded state dict creation.
Returns:
ShardedStateDict: sharded state dict for the GPTModel
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
# Old GPT checkpoints only stored the output layer weight key. So we remove the
# _extra_state key but check that it doesn't contain any data anyway
output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
assert not (
output_extra_state and output_extra_state.data
), f'Expected output layer extra state to be empty, got: {output_extra_state}'
# Multi-Token Prediction (MTP) needs both the embedding layer and the output layer in
# the mtp process stage.
# If MTP is not placed in the pre processing stage, we need to maintain a copy of
# embedding layer in the mtp process stage and tie it to the embedding in the pre
# processing stage.
# Also, if MTP is not placed in the post processing stage, we need to maintain a copy
# of output layer in the mtp process stage and tie it to the output layer in the post
# processing stage.
if self.mtp_process and not self.pre_process:
emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
emb_weight = self.embedding.word_embeddings.weight
tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key)
if self.mtp_process and not self.post_process:
# We only need to tie the output layer weight if share_embeddings_and_output_weights
# is False. Because if share_embeddings_and_output_weights is True, the shared weight
# will be stored in embedding layer, and output layer will not have any weight.
if not self.share_embeddings_and_output_weights:
output_layer_weight_key = f'{prefix}output_layer.weight'
output_layer_weight = self.output_layer.weight
tie_output_layer_state_dict(
sharded_state_dict, output_layer_weight, output_layer_weight_key
)
return loss
return sharded_state_dict
# Modified from mindspeed.
import torch
from functools import wraps
from typing import List, Optional
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.training.global_vars import get_args, get_timers
from megatron.training.utils import unwrap_model
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.legacy.model.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
from .dualpipev_schedules import get_dualpipe_chunk
def dualpipev_fp16forward(self, *inputs, **kwargs):
is_pipeline_first_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 0
if is_pipeline_first_stage:
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
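# In dualpipev the model is folded in half, so the logical last stage is chunk 1 on the
# first pipeline rank; hence the is_pipeline_first_stage() check below.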
is_pipeline_last_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
if is_pipeline_last_stage:
outputs = float16_to_fp32(outputs)
return outputs
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
pre_process, post_process = False, False
if mpu.is_pipeline_first_stage():
pre_process = True
args.dualpipev_first_chunk = True
first_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
first_model.model_type = model_type
model.append(first_model)
args.dualpipev_first_chunk = False
second_model = model_provider_func(
pre_process=post_process,
post_process=pre_process
)
second_model.model_type = model_type
model.append(second_model)
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(
param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
if wrap_with_ddp:
config = get_model_config(model[0])
ddp_config = DistributedDataParallelConfig(
grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32,
overlap_grad_reduce=args.overlap_grad_reduce,
use_distributed_optimizer=args.use_distributed_optimizer,
check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad,
bucket_size=args.ddp_bucket_size,
average_in_collective=args.ddp_average_in_collective)
model = [DDP(config,
ddp_config,
model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0))
for (model_chunk_idx, model_chunk) in enumerate(model)]
# Broadcast params from data parallel src rank to other data parallel ranks.
if args.data_parallel_random_init:
for model_module in model:
model_module.broadcast_params()
return model
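A minimal sketch (not part of the commit) of the chunk placement implied by get_model above: on the first pipeline rank, chunk 0 owns the embedding and chunk 1 owns the output layer/loss, while every other rank builds two middle chunks; the helper name is hypothetical.

def dualpipev_chunk_flags(is_first_pipeline_rank: bool):
    """Illustrative helper mirroring the pre/post_process flags used in get_model()."""
    pre = is_first_pipeline_rank
    return [
        {'pre_process': pre, 'post_process': False},   # chunk 0: embedding on the first rank
        {'pre_process': False, 'post_process': pre},   # chunk 1: output layer/loss on the first rank
    ]

assert dualpipev_chunk_flags(True) == [
    {'pre_process': True, 'post_process': False},
    {'pre_process': False, 'post_process': True},
]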
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
rerun_state_machine = get_rerun_state_machine()
while rerun_state_machine.should_run_forward_backward(data_iterator):
# Set grad to zero.
for model_chunk in model:
model_chunk.zero_grad_buffer()
optimizer.zero_grad()
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False)
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None
# Empty unused memory.
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# Vision gradients.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
# Empty unused memory.
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
dualpipev_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
if dualpipev_last_stage:
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0].keys():
numerator = 0
denominator = 0
for x in losses_reduced:
val = x[key]
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if isinstance(val, tuple) or isinstance(val, list):
numerator += val[0]
denominator += val[1]
else:
# legacy behavior. we average over the number of microbatches,
# and so the denominator is 1.
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
def get_num_layers_to_build(config: TransformerConfig) -> int:
num_layers_per_pipeline_rank = (
config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
num_layers_to_build = num_layers_per_pipeline_rank // 2
return num_layers_to_build
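# Example: with config.num_layers == 32 and a pipeline-parallel world size of 4, each
# pipeline rank is assigned 32 // 4 = 8 layers, which are split evenly between its two
# dualpipev chunks, so get_num_layers_to_build returns 4.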
def _allreduce_embedding_grads_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if get_args().schedules_method == 'dualpipev':
# dualpipev does not need the embedding-grad allreduce:
# the embedding and the lm head live on the same rank.
if not get_args().untie_embeddings_and_output_weights:
raise NotImplementedError
else:
return
else:
return fn(*args, **kwargs)
return wrapper
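# Illustrative (hypothetical) usage of the wrapper above, applied as a monkey patch over
# Megatron's embedding-grad all-reduce:
#   _allreduce_embedding_grads = _allreduce_embedding_grads_wrapper(_allreduce_embedding_grads)
# With --schedules-method dualpipev the all-reduce then becomes a no-op (and tied
# embeddings are rejected as not implemented).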
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import contextlib
from functools import wraps
from typing import Iterator, List, Union
import torch
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.training import get_args
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
get_attr_wrapped_model,
get_model_config,
get_model_type,
)
from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel.p2p_communication import _communicate
from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
from megatron.core.models.gpt import GPTModel
from mindspeed.core.pipeline_parallel.fb_overlap.gpt_model import gpt_model_backward
from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import P2PCommParams
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
# Types
Shape = Union[List[int], torch.Size]
LOSS_BACKWARD_SCALE = torch.tensor(1.0)
_DUALPIPE_CHUNK = None
def set_dualpipe_chunk(chunkid):
"""set_dualpipe_chunk for fp16forward patch"""
global _DUALPIPE_CHUNK
_DUALPIPE_CHUNK = chunkid
def get_dualpipe_chunk():
global _DUALPIPE_CHUNK
if _DUALPIPE_CHUNK is not None:
return _DUALPIPE_CHUNK
else:
raise AssertionError("_DUALPIPE_CHUNK is None")
def is_dualpipev_last_stgae(model_chunk_id):
return parallel_state.is_pipeline_first_stage() and model_chunk_id == 1
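# In dualpipev the pipeline is folded into a V: model chunk 0 flows from rank 0 towards
# the last rank, and model chunk 1 flows back towards rank 0, so the logical last stage
# of the model is chunk 1 running on the pipeline-first rank.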
def send_forward(output_tensor: torch.Tensor, tensor_shape, config: ModelParallelConfig, model_chunk_id, async_op=False):
"""Send forward-pass output to the adjacent pipeline rank
(next rank for model chunk 0, previous rank for model chunk 1).
See _communicate for argument details.
"""
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage():
return None
tensor_send_next = output_tensor
else:
if parallel_state.is_pipeline_first_stage():
return None
tensor_send_prev = output_tensor
if config.timers is not None:
config.timers('forward-send', log_level=2).start()
_, _, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op)
)
if config.timers is not None:
config.timers('forward-send').stop()
return fwd_wait_handles
def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelParallelConfig, model_chunk_id, async_op=False):
"""Send input gradient to the adjacent pipeline rank (backward send):
previous rank for model chunk 0, next rank for model chunk 1.
See _communicate for argument details.
"""
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_first_stage():
return None
tensor_send_prev = input_tensor_grad
else:
if parallel_state.is_pipeline_last_stage():
return None
tensor_send_next = input_tensor_grad
if config.timers is not None:
config.timers('backward-send', log_level=2).start()
_, _, reqs = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op)
)
if config.timers is not None:
config.timers('backward-send').stop()
return reqs
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False):
"""Receive a forward-pass activation from the adjacent pipeline rank
(previous rank for model chunk 0, next rank for model chunk 1).
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
if model_chunk_id == 0:
recv_prev = True
else:
recv_next = True
if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
fwd_wait_handles = None
return None, fwd_wait_handles
else:
if config.timers is not None:
config.timers('forward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op),
)
if config.timers is not None:
config.timers('forward-recv').stop()
if recv_prev:
return tensor_recv_prev, fwd_wait_handles
else:
return tensor_recv_next, fwd_wait_handles
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False):
"""Receive an output gradient from the adjacent pipeline rank (backward receive):
next rank for model chunk 0, previous rank for model chunk 1.
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
if model_chunk_id == 0:
recv_next = True
else:
recv_prev = True
if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
output_tensor_grad = None
bwd_wait_handles = None
return output_tensor_grad, bwd_wait_handles
else:
if config.timers is not None:
config.timers('backward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, bwd_wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
wait_on_reqs=(not async_op)
)
if config.timers is not None:
config.timers('backward-recv').stop()
if recv_prev:
return tensor_recv_prev, bwd_wait_handles
else:
return tensor_recv_next, bwd_wait_handles
def send_forward_recv_forward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig,
model_chunk_id,
async_op=False
):
"""Batched p2p exchange for one model chunk: send this chunk's forward output to its
downstream neighbor and receive its next forward input from its upstream neighbor
(downstream is the next rank for chunk 0 and the previous rank for chunk 1).
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if not parallel_state.is_pipeline_last_stage():
tensor_send_next = output_tensor
if not parallel_state.is_pipeline_first_stage():
recv_prev = True
if model_chunk_id == 1:
if not parallel_state.is_pipeline_first_stage():
tensor_send_prev = output_tensor
if not parallel_state.is_pipeline_last_stage():
recv_next = True
if config.timers is not None:
config.timers('forward-send-forward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not async_op),
config=config
)
if config.timers is not None:
config.timers('forward-send-forward-recv').stop()
if model_chunk_id == 0:
if not parallel_state.is_pipeline_first_stage():
return tensor_recv_prev, fwd_wait_handles
else:
return None, fwd_wait_handles
else:
if not parallel_state.is_pipeline_last_stage():
return tensor_recv_next, fwd_wait_handles
else:
return None, fwd_wait_handles
def send_forward_recv_slave_forward(
output_tensor: torch.Tensor,
tensor_shape: Shape,
config: ModelParallelConfig,
model_chunk_id,
async_op=False,
):
"""Batched p2p exchange across the two chunks: send this chunk's forward output and
receive the other (slave) chunk's forward input from the same neighbor
(next rank for chunk 0, previous rank for chunk 1).
See _communicate for argument details.
"""
recv_prev, recv_next = False, False
tensor_send_next, tensor_send_prev = None, None
if model_chunk_id == 0:
if parallel_state.is_pipeline_last_stage():
return None, None
tensor_send_next = output_tensor
recv_next = True
if model_chunk_id == 1:
if parallel_state.is_pipeline_first_stage():
return None, None
tensor_send_prev = output_tensor
recv_prev = True
if config.timers is not None:
config.timers('forward-send-slave-forward-recv', log_level=2).start()
tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not async_op),
config=config,
)
if config.timers is not None:
config.timers('forward-send-slave-forward-recv').stop()
if model_chunk_id == 0:
return tensor_recv_next, fwd_wait_handles
else:
return tensor_recv_prev, fwd_wait_handles
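# Summary of the p2p directions used by the helpers above: model chunk 0 sends
# activations to the next rank and gradients to the previous rank, while model chunk 1
# does the opposite (activations to the previous rank, gradients to the next rank).
# send_forward_recv_slave_forward pairs a send for one chunk with a receive for the
# other chunk from the same neighbor in a single batched call.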
def generate_dualpipev_schedule(pp_size, num_microbatches):
num_microbatches = num_microbatches * 2
num_warmup_stages = [0] * pp_size
num_interleaved_forward_stages = [0] * pp_size
num_1b1w1f_stages = [0] * pp_size
num_overlap_stages = [0] * pp_size
num_1b1overlap_stages = [0] * pp_size
num_interleaved_backward_stages = [0] * pp_size
num_cooldown_stages = [0] * pp_size
pp_size *= 2
for i in range(pp_size // 2):
num_warmup_stages[i] = pp_size - 2 - i * 2
num_interleaved_forward_stages[i] = i + 1  # each unit is one pair of forward passes (1F1F)
num_1b1w1f_stages[i] = pp_size // 2 - i - 1
num_overlap_stages[i] = num_microbatches - pp_size * 2 + i * 2 + 2
num_1b1overlap_stages[i] = (pp_size // 2 - i - 1) * 2
num_interleaved_backward_stages[i] = i + 1
num_cooldown_stages[i] = [i + 1, pp_size - 2 * i - 2, i + 1]
schedule_all_stages = {
'warmup': num_warmup_stages,
'interleaved_forward': num_interleaved_forward_stages,
'1b1w1f': num_1b1w1f_stages,
'overlap': num_overlap_stages,
'1b1overlap': num_1b1overlap_stages,
'interleaved_backward': num_interleaved_backward_stages,
'cooldown': num_cooldown_stages
}
return schedule_all_stages
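# Illustrative example: with pp_size=4 and num_microbatches=8 (doubled to 16 internally),
# generate_dualpipev_schedule returns, per pipeline rank 0..3:
#   warmup:               [6, 4, 2, 0]
#   interleaved_forward:  [1, 2, 3, 4]
#   1b1w1f:               [3, 2, 1, 0]
#   overlap:              [2, 4, 6, 8]
#   1b1overlap:           [6, 4, 2, 0]
#   interleaved_backward: [1, 2, 3, 4]
#   cooldown:             [[1, 6, 1], [2, 4, 2], [3, 2, 3], [4, 0, 4]]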
def pretrain_gpt_forward_step_dualpipe(data_iterator, model: GPTModel, extra_block_kwargs=None):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
extra_block_kwargs: optional kwargs that enable forward/backward overlapping
"""
from megatron.training import get_timers
from functools import partial
from pretrain_gpt import get_batch, loss_func
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
if extra_block_kwargs is not None:
# execute forward/backward overlapping
output_tensor, model_graph, pp_comm_output = \
model(tokens, position_ids, attention_mask, labels=labels,
extra_block_kwargs=extra_block_kwargs)
return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
else:
output_tensor, model_graph = model(
tokens, position_ids, attention_mask, labels=labels)
return (output_tensor, model_graph), partial(loss_func, loss_mask)
def forward_step_no_model_graph(
forward_step_func,
model_chunk_id,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
):
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
num_tokens = torch.tensor(0, dtype=torch.int)
if is_dualpipev_last_stgae(model_chunk_id):
if not collect_non_loss_data:
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
def backward_step_with_model_graph(input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph=None):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if config.timers is not None:
config.timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
if config.deallocate_pipeline_outputs:
if model_graph is None:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
layer_output_grad = gpt_model_backward(
output_tensor_grad[0], model_graph)
else:
torch.autograd.backward(
output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
if model_graph is not None:
input_tensor_grad.append(layer_output_grad)
else:
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
if config.timers is not None:
config.timers('backward-compute').stop()
return input_tensor_grad
def forward_step_with_model_graph(
forward_step_func,
model_chunk_id,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
extra_block_kwargs=None,
):
"""Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.
Args:
forward_step_func (callable): The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
for this object is that it needs to be acceptable as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally.
a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
c. Any arbitrary data the user wants (e.g. a dictionary of tensors, a list
of tensors, etc. in the case of inference). To trigger case (c) you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator): The data iterator.
model (nn.Module): The model to perform the forward step on.
num_microbatches (int): The number of microbatches.
input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step.
forward_data_store (list): The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object): The configuration object.
collect_non_loss_data (bool, optional): Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional): The current microbatch. Defaults to None.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
data_iterator, model, extra_block_kwargs)
else:
# checkpoint_activations_microbatch is not supported by the dualpipe forward step
output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
data_iterator, model, extra_block_kwargs
)
num_tokens = torch.tensor(0, dtype=torch.int)
if is_dualpipev_last_stgae(model_chunk_id):
if not collect_non_loss_data:
next_info = None
if isinstance(output_tensor, tuple):
# pp overlapping is in use
if len(output_tensor) == 2:
output_tensor, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor, model_graph, next_info = output_tensor
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
output_tensor = (output_tensor, model_graph, next_info) if next_info is not None else (
output_tensor, model_graph)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(LOSS_BACKWARD_SCALE)
if config.grad_scale_func is not None
else torch.tensor(1.0)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
shared_embedding = None
def get_shared_embedding_from_dual_chunk():
assert shared_embedding is not None
return shared_embedding
def set_shared_embedding_from_dual_chunk(model1, model2):
global shared_embedding
if shared_embedding is not None:
return
if model1.module.module.pre_process:
shared_embedding = model1.module.module.embedding.word_embeddings.weight
elif model2.module.module.pre_process:
shared_embedding = model2.module.module.embedding.word_embeddings.weight
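# On the pipeline-first rank, chunk 0 is built with pre_process=True (it owns the word
# embeddings) and chunk 1 with post_process=True (it owns the output layer), so the
# shared weight is taken from whichever DDP-wrapped chunk exposes pre_process. On other
# ranks neither chunk has pre_process and shared_embedding stays None.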
def forward_backward_pipelining_with_cutinhalf(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
args = get_args()
args.moe_fb_overlap = True
args.dualpipe_no_dw_detach = True
set_shared_embedding_from_dual_chunk(model[0], model[1])
assert (
isinstance(model, list) and len(model) == 2
), 'The dualpipev schedule requires the model to be split into two consecutive chunks'
assert (
isinstance(data_iterator, list) and len(data_iterator) == 2
), 'The dualpipev schedule requires exactly two data_iterators'
config = get_model_config(model[0])
config.batch_p2p_comm = False
# Needed only when gradients are finalized in M-Core
if config.finalize_model_grads_func is not None and not forward_only:
embedding_module = clear_embedding_activation_buffer(config, model)
if config.timers is not None:
config.timers('forward-backward',
log_level=1).start(barrier=config.barrier_with_L1_time)
# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Compute number of steps for each stage
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank = parallel_state.get_pipeline_model_parallel_rank()
schedule = generate_dualpipev_schedule(pp_size, num_microbatches)
model_type = get_model_type(model[0])
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
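# Example: with seq_length=4096, a context-parallel size of 2, and sequence parallelism
# over a tensor-parallel size of 4, each p2p tensor carries 4096 // 2 // 4 = 512 tokens
# per microbatch, i.e. tensor_shape == [512, micro_batch_size, hidden_size].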
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
input_tensors = [[], []]
output_tensors = [[], []]
model_graphs = [[], []]
logits_inputs = []
forward_data_store = []
master_chunk_id = 0
slave_chunk_id = 1
master_cur_microbatch = 0
slave_cur_microbatch = num_microbatches
master_microbatch_max = num_microbatches
slave_microbatch_max = num_microbatches * 2
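# Microbatch numbering: the master chunk (chunk 0) labels its microbatches
# [0, num_microbatches) and the slave chunk (chunk 1) labels its
# [num_microbatches, 2 * num_microbatches), so each rank processes a forward and a
# backward pass for both chunks over the course of one global step.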
set_dualpipe_chunk(master_chunk_id)
checkpoint_activations_microbatch = None
def forward_step_helper(model_chunk_id, current_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=False, extra_block_kwargs=None):
input_tensor = input_tensors[model_chunk_id][-1][1]
output_tensor, num_tokens = forward_step_with_model_graph(
forward_step_func,
model_chunk_id,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch,
current_microbatch=current_microbatch,
extra_block_kwargs=extra_block_kwargs
)
if isinstance(output_tensor, tuple):
if len(output_tensor) == 2:
output_tensor_, model_graph = output_tensor
elif len(output_tensor) == 3:
output_tensor_, model_graph, pp_comm_output = output_tensor
if is_dualpipev_last_stgae(model_chunk_id):
logits_inputs.append(
model_graph.layer_graphs[-1].unperm2_graph[1])
model_graphs[model_chunk_id].append(model_graph)
else:
output_tensor_ = output_tensor
output_tensors[model_chunk_id].append(output_tensor_)
if extra_block_kwargs is not None:
input_tensors[1 - model_chunk_id].pop(0)
output_tensors[1 - model_chunk_id].pop(0)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def check_pipeline_stage(model_chunk_id, fwd_send_only):
send_next, recv_next, send_prev, recv_prev = True, True, True, True
if parallel_state.is_pipeline_first_stage():
send_prev, recv_prev = False, False
if parallel_state.is_pipeline_last_stage():
send_next, recv_next = False, False
if model_chunk_id == 0:
return P2PCommParams(send_next=send_next, recv_next=not fwd_send_only and recv_next), P2PCommParams(send_next=send_next, recv_next=recv_next)
else:
return P2PCommParams(send_prev=send_prev, recv_prev=not fwd_send_only and recv_prev), P2PCommParams(send_prev=send_prev, recv_prev=recv_prev)
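# check_pipeline_stage returns (fwd_pp_comm_params, bwd_pp_comm_params) for the
# overlapped step: both use the next-rank direction for chunk 0 and the previous-rank
# direction for chunk 1. The pipeline-first and pipeline-last ranks drop the prev/next
# flags that would fall off the ends of the pipeline, and fwd_send_only additionally
# suppresses only the forward receive.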
input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]
fwd_wait_handles_warmup = None
# Run warmup forward passes
for i in range(schedule['warmup'][rank]):
if args.moe_fb_overlap:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensor_warmup, _ = forward_step_helper(master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=(i == 0))
else:
output_tensor_warmup, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=(i == 0),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
master_cur_microbatch += 1
if i != schedule['warmup'][rank] - 1:
input_tensor, _ = send_forward_recv_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id)
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
else:
input_tensor, _ = recv_forward(
tensor_shape, config, master_chunk_id)
fwd_wait_handles_warmup = send_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id, async_op=True)
# Run interleaved forward passes for the two model chunks
fwd_wait_handles = None
fwd_wait_handles_slave_chunk = None
fwd_wait_handles_send = None
for i in range(schedule['interleaved_forward'][rank]):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
is_first_microbatch = parallel_state.is_pipeline_last_stage() and (i == 0)
set_dualpipe_chunk(master_chunk_id)
if args.moe_fb_overlap:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensor, _ = forward_step_helper(master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch)
else:
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch,
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
master_cur_microbatch += 1
if not parallel_state.is_pipeline_last_stage() and fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if parallel_state.is_pipeline_last_stage():
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
else:
input_tensor_slave_chunk, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
if fwd_wait_handles_warmup is not None:
for req in fwd_wait_handles_warmup:
req.wait()
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
fwd_wait_handles_warmup = None
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
set_dualpipe_chunk(slave_chunk_id)
if args.moe_fb_overlap:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensor_slave_chunk, _ = forward_step_helper(
slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
else:
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch,
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if i == schedule['interleaved_forward'][rank] - 1:
firstFB_no_overlp = False
firstFB_no_overlp_handle = None
# the last pipeline rank does not overlap the first forward & backward
if parallel_state.is_pipeline_last_stage():
firstFB_no_overlp = True
output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
tensor_shape, config, slave_chunk_id, async_op=True)
else:
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
tensor_shape, config, slave_chunk_id, async_op=True)
if not parallel_state.is_pipeline_last_stage():
output_tensor_send = output_tensor
fwd_wait_handles_send = send_forward(
output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
# Run 1b1w1f stages for the slave chunk
bwd_wait_handles = None
for _ in range(schedule['1b1w1f'][rank]):
WeightGradStore.start_decouple()
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(slave_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[slave_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
model_graph = model_graphs[slave_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
WeightGradStore.end_decouple()
# Send synchronously; an asynchronous send here would increase memory usage.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
# Keep this receive synchronous; an asynchronous receive here would increase memory usage.
input_tensor_slave_chunk, recv_forward_handle = recv_forward(
tensor_shape, config, slave_chunk_id)
# 1w: Weight Grad Compute
WeightGradStore.pop()
if recv_forward_handle is not None:
for req in recv_forward_handle:
req.wait()
recv_forward_handle = None
# 1F: Forward pass
set_dualpipe_chunk(slave_chunk_id)
if args.moe_fb_overlap:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensor_slave_chunk, _ = forward_step_helper(
slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
else:
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
tensor_shape, config, slave_chunk_id, async_op=True)
fwd_wait_handles_recv = None
# Run the overlapping forward & backward stages
fwd_model_chunk_id = master_chunk_id
bwd_model_chunk_id = slave_chunk_id
for _ in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
only_bwd = False
if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
only_bwd = True
if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
only_bwd = True
if args.moe_fb_overlap and not firstFB_no_overlp:
if not only_bwd:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
if fwd_wait_handles_recv is not None:
for req in fwd_wait_handles_recv:
req.wait()
fwd_wait_handles_recv = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if not parallel_state.is_pipeline_last_stage() or fwd_model_chunk_id == master_chunk_id:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
fwd_send_only = False
if fwd_model_chunk_id == slave_chunk_id and master_cur_microbatch == master_microbatch_max:
fwd_send_only = True
extra_block_kwargs = {}
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
extra_block_kwargs.setdefault(
'bwd_model_grad', input_tensor_grad)
else:
extra_block_kwargs.setdefault(
'bwd_model_grad', output_tensor_grad_bwd)
fwd_pp_comm_params, bwd_pp_comm_params = check_pipeline_stage(
fwd_model_chunk_id, fwd_send_only)
fwd_pp_comm_params.config, bwd_pp_comm_params.config = config, config
fwd_pp_comm_params.tensor_shape, bwd_pp_comm_params.tensor_shape = tensor_shape, tensor_shape
extra_block_kwargs.setdefault(
'bwd_model_graph', model_graphs[bwd_model_chunk_id].pop(0))
extra_block_kwargs.setdefault(
'pp_comm_params', fwd_pp_comm_params)
extra_block_kwargs.setdefault(
'bwd_pp_comm_params', bwd_pp_comm_params)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensor, model_graph, pp_comm_output = forward_step_helper(fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch,
extra_block_kwargs=extra_block_kwargs)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
output_tensor_grad_bwd = pp_comm_output.input_tensor_grad
else:
input_tensor, fwd_wait_handles = pp_comm_output.input_tensor, pp_comm_output.fwd_wait_handles
output_tensor_grad_bwd, bwd_wait_handles = pp_comm_output.output_tensor_grad, pp_comm_output.bwd_wait_handles
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
else:
slave_cur_microbatch += 1
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:  # wait for the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
else:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, fwd_wait_handles_recv = recv_forward(
tensor_shape, config, slave_chunk_id, async_op=True)
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
else:
firstFB_no_overlp = False
if not only_bwd:
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
if args.moe_fb_overlap:
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensor, _ = forward_step_helper(
fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch)
else:
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
fwd_model_chunk_id,
data_iterator[fwd_model_chunk_id],
model[fwd_model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
if firstFB_no_overlp_handle is not None:
for req in firstFB_no_overlp_handle:
req.wait()
firstFB_no_overlp_handle = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
fwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:  # wait for the previous phase's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
# only run backward
else:
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if args.moe_fb_overlap:
if is_dualpipev_last_stgae(bwd_model_chunk_id):
input_tensor_bwd = logits_inputs.pop(0)
output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
model_graph = None
output_tensor_grad_bwd = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
# swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
# Run cooldown phases
merged_input_tensors = []
merged_output_tensors = []
while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
if len(input_tensors[bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[1 - bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
bwd_wait_handles_recv = None
for i in range(pp_size):
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if bwd_wait_handles_recv is not None:
for req in bwd_wait_handles_recv:
req.wait()
bwd_wait_handles_recv = None
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
if not args.dualpipe_no_dw_detach:
WeightGradStore.start_decouple()
if args.moe_fb_overlap:
model_graph = model_graphs[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_with_model_graph(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph
)
else:
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if not args.dualpipe_no_dw_detach:
WeightGradStore.end_decouple()
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
elif i >= schedule['cooldown'][rank][0] - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
tensor_shape, config, bwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
WeightGradStore.flush_chunk_grad()
if i >= schedule['cooldown'][rank][0] - 1:
WeightGradStore.pop_single()
for _ in range(schedule['cooldown'][rank][2] - 1):
WeightGradStore.pop_single()
assert WeightGradStore.weight_grad_queue.empty()
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
bwd_wait_handles = None
if config.finalize_model_grads_func is not None and not forward_only:
# If defer_embedding_wgrad_compute is enabled we need to do the
# weight gradient GEMM's here.
finish_embedding_wgrad_compute(config, embedding_module)
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config.finalize_model_grads_func(
model, total_num_tokens if config.calculate_per_token_loss else None
)
return forward_data_store
from .modules.layers import linear_backward_wgrad_detach, ColumnParallelLinear, RowParallelLinear
from .modules.experts import group_mlp_forward_detach
from .transformer_layer import transformer_layer_forward_backward_overlaping
from .gpt_model import gpt_model_forward_backward_overlaping
from .vpp_schedules import forward_backward_pipelining_with_interleaving
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
def _make_param_hook(
self,
param: torch.nn.Parameter,
param_to_buffer,
):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
"""
def param_hook(*unused):
if param.requires_grad and not getattr(param, 'skip_grad_accum', False):
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
# Note: this should perhaps be called after WeightGradStore.pop()
if self.ddp_config.overlap_grad_reduce:
param_to_buffer[param].register_grad_ready(param)
if getattr(param, 'skip_grad_accum', False):
param.skip_grad_accum = False
return param_hook
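# Behavior sketch: when a parameter is tagged with skip_grad_accum (e.g. while its weight
# gradient is deferred through WeightGradStore), the hook skips both the accumulation into
# main_grad and the register_grad_ready call, and only clears the flag so that the next
# backward pass accumulates normally again.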
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import logging
from typing import Dict, Literal, Optional, Tuple, Union, List
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams
from .transformer_block import (
transformer_block_backward, transformer_block_forward_backward_overlaping,
transformer_block_forward
)
from .modules.utils import (
LayerGraph, detach_tensor, run_graph_backward
)
class ModelGraph:
def __init__(
self,
layer_graphs: List[LayerGraph],
block_output,
preprocess_graph: Tensor = None,
preprocess_detached_output: Tensor = None,
):
self.preprocess_graph = (preprocess_graph, preprocess_detached_output)
self.layer_graphs = layer_graphs
self.block_output = block_output
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if decoder_input is not None:
preprocess_graph = None
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
preprocess_graph = decoder_input
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
preprocess_graph = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
detached_block_input = detach_tensor(decoder_input)
# Run decoder.
hidden_states, layer_graphs = transformer_block_forward(
self.decoder,
hidden_states=detached_block_input,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states, ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
logits, _ = self.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
logits = logits.transpose(0, 1).contiguous()
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return logits, graph
loss = self.compute_language_model_loss(labels, logits)
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return loss, graph
def gpt_model_backward(
model_grad,
model_graph: ModelGraph,
):
block_input_grad = transformer_block_backward(model_grad, model_graph.layer_graphs)
if model_graph.preprocess_graph[0] is not None:
run_graph_backward(model_graph.preprocess_graph, block_input_grad, keep_graph=True, keep_grad=True)
return None
else:
return block_input_grad
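# Note: when this chunk ran the preprocessing (embedding) stage, the block input gradient
# is propagated into the embedding graph via run_graph_backward and nothing is returned;
# otherwise the gradient w.r.t. the block input is handed back so it can be sent to the
# previous logical stage.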
def gpt_model_forward_backward_overlaping(
fwd_model,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
):
if extra_block_kwargs is None or extra_block_kwargs['bwd_model_graph'] is None:
return gpt_model_forward(
fwd_model, input_ids, position_ids, attention_mask, decoder_input, labels, inference_params,
packed_seq_params, extra_block_kwargs
)
bwd_model_grad, bwd_model_graph = extra_block_kwargs['bwd_model_grad'], extra_block_kwargs['bwd_model_graph']
# Fwd model decoder embedding.
if decoder_input is not None:
preprocess_graph = None
elif fwd_model.pre_process:
decoder_input = fwd_model.embedding(input_ids=input_ids, position_ids=position_ids)
preprocess_graph = decoder_input
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
preprocess_graph = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if fwd_model.position_embedding_type == 'rope':
rotary_seq_len = fwd_model.rotary_pos_emb.get_rotary_seq_len(
inference_params, fwd_model.decoder, decoder_input, fwd_model.config
)
rotary_pos_emb = fwd_model.rotary_pos_emb(rotary_seq_len)
detached_block_input = detach_tensor(decoder_input)
# Run transformer block fwd & bwd overlapping
(hidden_states, layer_graphs), block_input_grad, pp_comm_output \
= transformer_block_forward_backward_overlaping(
fwd_model.decoder,
detached_block_input,
attention_mask,
bwd_model_grad,
bwd_model_graph.layer_graphs,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
pp_comm_params=extra_block_kwargs['pp_comm_params'],
bwd_pp_comm_params=extra_block_kwargs['bwd_pp_comm_params']
)
if bwd_model_graph.preprocess_graph[0] is not None:
run_graph_backward(bwd_model_graph.preprocess_graph, block_input_grad, keep_grad=True, keep_graph=True)
if not fwd_model.post_process:
return hidden_states, ModelGraph(layer_graphs, hidden_states, preprocess_graph,
detached_block_input), pp_comm_output
# logits and loss
output_weight = None
if fwd_model.share_embeddings_and_output_weights:
output_weight = fwd_model.shared_embedding_or_output_weight()
logits, _ = fwd_model.output_layer(hidden_states, weight=output_weight)
if labels is None:
# [s b h] => [b s h]
logits = logits.transpose(0, 1).contiguous()
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return logits, graph, pp_comm_output
loss = fwd_model.compute_language_model_loss(labels, logits)
graph = ModelGraph(
layer_graphs, hidden_states, preprocess_graph, detached_block_input
)
return loss, graph, pp_comm_output
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.training import get_args
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
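# Module-level FIFO queues: arguments staged for a deferred all-to-all, and the
# (output, handle) pairs produced once those all-to-alls are launched.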
AsyncAll2All_INPUT = []
AsyncAll2All_OUTPUT = []
def set_async_alltoall_inputs(*args):
AsyncAll2All_INPUT.append(args)
def get_async_alltoall_outputs():
return AsyncAll2All_OUTPUT.pop(0)
def launch_async_all2all():
global AsyncAll2All_INPUT
global AsyncAll2All_OUTPUT
if len(AsyncAll2All_INPUT) > 0:
input_, input_splits, output_splits, group = AsyncAll2All_INPUT.pop(0)
_, output, a2a_handle = async_all_to_all(
input_,
input_splits,
output_splits,
group
)
AsyncAll2All_OUTPUT.append((output, a2a_handle))
def launch_async_all2all_hook(_):
launch_async_all2all()
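# Rough usage sketch (illustrative only): stage an all-to-all, launch it from a
# backward hook, then collect the result. `tokens`, `in_splits`, `out_splits`,
# `ep_group` and `probe_tensor` are hypothetical placeholders.
#
#     set_async_alltoall_inputs(tokens, in_splits, out_splits, ep_group)
#     probe_tensor.register_hook(launch_async_all2all_hook)  # fires during backward
#     ...
#     output, handle = get_async_alltoall_outputs()
#     handle.wait()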
def attention_forward(
self,
hidden_states,
residual,
attention_mask=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
recompute_norm=False
):
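"""Attention half of a transformer layer: optional (and optionally checkpointed)
input layernorm, self attention, then the bias-dropout-add residual connection."""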
# Optional Input Layer norm
def pre_norm(hidden_states):
args = get_args()
input_layernorm_output = self.input_layernorm(hidden_states)
if getattr(args, 'input_layernorm_in_fp32', False):
input_layernorm_output = input_layernorm_output.float()
return input_layernorm_output
if recompute_norm:
self.norm_ckpt1 = CheckpointWithoutOutput()
input_layernorm_output = self.norm_ckpt1.checkpoint(pre_norm, False, hidden_states)
else:
input_layernorm_output = pre_norm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
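# The input layernorm was checkpointed without keeping its output: free that
# output now and recompute the norm during backward via the gradient hook.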
if recompute_norm:
self.norm_ckpt1.discard_output()
hidden_states.register_hook(self.norm_ckpt1.recompute)
return hidden_states