Commit d05234e0 authored by dongcl

init dcu_megatron

parent f8eedf6b
from .adaptor import megatron_adaptor
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import sys
import types
import argparse
import torch
class MegatronAdaptation:
"""
A module manager that supports adaptation registration, application, and execution.
"""
_patch_info_collection = {}
_args = None
@classmethod
def execute(cls):
"""
Execute adaptations.
"""
for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
adaptation.execute()
MegatronAdaptation.apply()
# MegatronAdaptation.post_execute()
@classmethod
def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False):
"""
Register adaptations into collection.
"""
if orig_func_name not in cls._patch_info_collection:
from .patch_utils import Patch
cls._patch_info_collection[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
else:
cls._patch_info_collection.get(orig_func_name).set_patch_func(new_func, force_patch)
@classmethod
def apply(cls):
"""
Apply adaptations.
"""
for patch in cls._patch_info_collection.values():
patch.apply_patch()
@classmethod
def post_execute(cls):
"""
Execute after other adaptations.
"""
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.transformer_block import TransformerBlock
class MegatronAdaptationABC:
"""
Abstract class for adaptation.
"""
@abc.abstractmethod
def execute(self):
"""
Do Adaptation
"""
class CoreAdaptation(MegatronAdaptationABC):
"""
Adaptations for models in Megatron-LM Core structure.
"""
def execute(self):
self.patch_core_distributed()
self.patch_core_models()
self.patch_core_transformers()
self.patch_tensor_parallel()
self.patch_training()
self.patch_miscellaneous()
def patch_core_distributed(self):
# Mtp share embedding
from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
def patch_core_models(self):
from ..core.models.common.embeddings.language_model_embedding import (
language_model_embedding_forward,
language_model_embedding_init_func
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init,
shared_embedding_or_mtp_embedding_weight
)
# Embedding
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
language_model_embedding_init_func)
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
language_model_embedding_forward)
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
from megatron.core.models.gpt.gpt_model import GPTModel
setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper, transformer_block_forward
from ..core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
# Transformer block
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.forward',
transformer_block_forward)
# Transformer config
MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig',
TransformerConfig)
# Transformer config
MegatronAdaptation.register('megatron.core.transformer.transformer_config.MLATransformerConfig',
MLATransformerConfig)
def patch_tensor_parallel(self):
from ..core import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
vocab_parallel_embedding_init)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
def patch_miscellaneous(self):
from ..training.arguments import parse_args
MegatronAdaptation.register('megatron.training.arguments.parse_args', parse_args)
class LegacyAdaptation(MegatronAdaptationABC):
"""
Adaptations for models in legacy structure.
"""
def execute(self):
pass
MegatronAdaptation.execute()
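# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Importing the
# package runs MegatronAdaptation.execute() above, so all patches are in place
# before any Megatron module is used:
#
#     import dcu_megatron  # noqa: F401  (triggers the adaptor on import)
#     from megatron.training import pretrain
#
# Extra patches can be registered the same way the adaptations above do;
# `my_parse_args` below is a hypothetical replacement function:
#
#     MegatronAdaptation.register('megatron.training.arguments.parse_args',
#                                 my_parse_args, force_patch=True)
#     MegatronAdaptation.apply()
# ---------------------------------------------------------------------------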
import importlib
import sys
import types
def get_func_name(func):
if isinstance(func, str):
return func
return '.'.join((func.__module__, func.__qualname__))
def dummy_function_wrapper(func_name):
def dummy_function(*args, **kwargs):
raise RuntimeError('function {} does not exist'.format(func_name))
return dummy_function
class Patch:
def __init__(self, orig_func_name, new_func, create_dummy):
split_name = orig_func_name.rsplit('.', 1)
if len(split_name) == 1:
self.orig_module_name, self.orig_func_name = orig_func_name, None
else:
self.orig_module_name, self.orig_func_name = split_name
self.orig_module = None
self.orig_func = None
self.patch_func = None
self.wrappers = []
if new_func is None:
new_func = dummy_function_wrapper(orig_func_name)
self.set_patch_func(new_func)
self.is_applied = False
self.create_dummy = create_dummy
@property
def orig_func_id(self):
return id(self.orig_func)
@property
def patch_func_id(self):
return id(self.patch_func)
def set_patch_func(self, new_func, force_patch=False):
if hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')):
self.wrappers.append(new_func)
else:
if self.patch_func and not force_patch:
raise RuntimeError('a patch for {} already exists!'.format(self.orig_func_name))
self.patch_func = new_func
self.is_applied = False
def apply_patch(self):
if self.is_applied:
return
self.orig_module, self.orig_func = Patch.parse_path(self.orig_module_name, self.orig_func_name, self.create_dummy)
final_patch_func = self.orig_func
if self.patch_func is not None:
final_patch_func = self.patch_func
for wrapper in self.wrappers:
final_patch_func = wrapper(final_patch_func)
if self.orig_func_name is not None:
setattr(self.orig_module, self.orig_func_name, final_patch_func)
for key, value in sys.modules.copy().items():
if self.orig_func_name is not None and hasattr(value, self.orig_func_name) \
and id(getattr(value, self.orig_func_name)) == self.orig_func_id:
setattr(value, self.orig_func_name, final_patch_func)
self.is_applied = True
@staticmethod
def parse_path(module_path, function_name, create_dummy):
from importlib.machinery import ModuleSpec
modules = module_path.split('.')
for i in range(1, len(modules) + 1):
parent = '.'.join(modules[:i - 1])
path = '.'.join(modules[:i])
try:
importlib.import_module(path)
except ModuleNotFoundError as e:
if not parent or not hasattr(importlib.import_module(parent), modules[i - 1]):
if not create_dummy:
raise ModuleNotFoundError(e) from e
sys.modules[path] = types.ModuleType(path)
sys.modules[path].__file__ = 'dcu_megatron.dummy_module.py'
sys.modules[path].__spec__ = ModuleSpec(path, None)
if parent:
setattr(importlib.import_module(parent), modules[i - 1], sys.modules[path])
else:
module = getattr(importlib.import_module(parent), modules[i - 1])
if hasattr(module, function_name):
return module, getattr(module, function_name)
elif create_dummy:
return module, dummy_function_wrapper(function_name)
else:
raise RuntimeError('{} does not exist in {}'.format(function_name, module))
if function_name is not None and not hasattr(sys.modules[module_path], function_name):
setattr(sys.modules[module_path], function_name, None)
return sys.modules[module_path], getattr(sys.modules[module_path], function_name) if function_name is not None else None
class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(orig_func_name, new_func=None, force_patch=False, create_dummy=False):
if orig_func_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
else:
MegatronPatchesManager.patches_info.get(orig_func_name).set_patch_func(new_func, force_patch)
@staticmethod
def apply_patches():
for patch in MegatronPatchesManager.patches_info.values():
patch.apply_patch()
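# ---------------------------------------------------------------------------
# Hedged demo (not part of the original module): a minimal sketch of the
# register/apply flow on a harmless stdlib target. 'math.sqrt' is used purely
# for illustration; real adaptations target megatron.* paths as in adaptor.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import math

    MegatronPatchesManager.register_patch('math.sqrt', lambda x: x ** 0.5)
    MegatronPatchesManager.apply_patches()
    assert math.sqrt(4.0) == 2.0  # the patched implementation is now in place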
from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
from .transformer.transformer_block import transformer_block_init_wrapper, transformer_block_forward
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from typing import List
import torch
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _unshard_if_dtensor, _reshard_if_dtensor
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_attr_wrapped_model
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync.
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
if hasattr(model_module,
"share_mtp_embedding_and_output_weight") and model_module.share_mtp_embedding_and_output_weight:
weight = model_module.shared_embedding_or_mtp_embedding_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from typing import Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
def language_model_embedding_init_func(
self,
config: TransformerConfig,
vocab_size: int,
max_sequence_length: int,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
skip_weight_param_allocation: bool = False
):
"""Patch language model embeddings init."""
super(LanguageModelEmbedding, self).__init__(config=config)
self.config: TransformerConfig = config
self.vocab_size: int = vocab_size
self.max_sequence_length: int = max_sequence_length
self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
self.num_tokentypes = num_tokentypes
self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
self.reduce_scatter_embeddings = (
(not self.add_position_embedding)
and self.num_tokentypes <= 0
and self.config.sequence_parallel
and self.scatter_to_sequence_parallel
)
# Word embeddings (parallel).
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
num_embeddings=self.vocab_size,
embedding_dim=self.config.hidden_size,
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
skip_weight_param_allocation=skip_weight_param_allocation
)
# Position embedding (serial).
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
self.max_sequence_length, self.config.hidden_size
)
# Initialize the position embeddings.
if self.config.perform_initialization:
self.config.init_method(self.position_embeddings.weight)
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(
self.num_tokentypes, self.config.hidden_size
)
# Initialize the token-type embeddings.
if self.config.perform_initialization:
self.config.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)
def language_model_embedding_forward(self,
input_ids: Tensor,
position_ids: Tensor,
tokentype_ids: int = None,
weight: Tensor = None) -> Tensor:
"""Pacth forward pass of the embedding module.
Args:
input_ids (Tensor): The input tokens
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
weight (Tensor): embedding weight
Returns:
Tensor: The output embeddings
"""
if weight is None:
if self.word_embeddings.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.word_embeddings.weight
word_embeddings = self.word_embeddings(input_ids, weight)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
else:
embeddings = word_embeddings
if not self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
# [b s h] -> [s b h] (So that it can be added with embeddings)
tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
embeddings = embeddings + tokentype_embedding
else:
assert self.tokentype_embeddings is None
# If the input flag for fp32 residual connection is set, convert for float.
if self.config.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
if self.config.sequence_parallel:
if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# `scatter_to_sequence_parallel_region` returns a view, which prevents
# the original tensor from being garbage collected. Clone to facilitate GC.
# Has a small runtime cost (~0.5%).
if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
embeddings = embeddings.clone()
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
return embeddings
import logging
from typing import Literal, Optional
from functools import wraps
from collections import OrderedDict
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
def gpt_model_init(
self,
config: TransformerConfig,
transformer_layer_spec: ModuleSpec,
vocab_size: int,
max_sequence_length: int,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
share_embeddings_and_output_weights: bool = False,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
rope_scaling: bool = False,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
mtp_spec: ModuleSpec = None
) -> None:
super(GPTModel, self).__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.position_embedding_type = position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self.model_type = ModelType.encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self.max_position_embeddings = max_sequence_length
self.rotary_percent = rotary_percent
self.rotary_base = rotary_base
self.rotary_scaling = rope_scaling
if self.pre_process:
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=position_embedding_type,
scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
)
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
rope_scaling=rope_scaling,
use_cpu_initialization=self.config.use_cpu_initialization,
)
# Transformer.
self.decoder = TransformerBlock(
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process
)
# Output
if post_process:
if self.config.defer_embedding_wgrad_compute:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=False,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.pre_process
and self.share_embeddings_and_output_weights,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
# add mtp
self.mtp_spec: ModuleSpec = mtp_spec
self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = self.config.recompute_mtp_norm
self.recompute_mtp_layer = self.config.recompute_mtp_layer
self.mtp_loss_scale = self.config.mtp_loss_scale
if self.post_process and self.training and self.num_nextn_predict_layers:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
)
for i in range(self.num_nextn_predict_layers)
]
)
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
if self.num_nextn_predict_layers and (self.pre_process or self.post_process):
setup_mtp_embeddings(self)
def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
"""Gets the embedding weight when share embedding and mtp embedding weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
mtp embedding layers weight
"""
assert self.num_nextn_predict_layers > 0
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.mtp_layers[0].embedding.word_embeddings.weight
return None
def setup_mtp_embeddings(self):
"""
Share embedding layer in mtp layer.
"""
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
# Set `is_embedding_or_output_parameter` attribute.
for i in range(self.num_nextn_predict_layers):
if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if not self.share_mtp_embedding_and_output_weight:
return
if self.pre_process and self.post_process:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
return
if self.pre_process and not self.post_process:
assert parallel_state.is_pipeline_first_stage()
self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
for i in range(self.num_nextn_predict_layers):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_mtp_embedding_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
if self.num_nextn_predict_layers == 0:
return (
[input_ids],
[labels],
[position_ids],
[attention_mask],
)
return (
tensor_slide(input_ids, self.num_nextn_predict_layers),
tensor_slide(labels, self.num_nextn_predict_layers),
generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers),
# not compatible with ppo attn_mask
tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
)
def generate_nextn_position_ids(tensor, slice_num):
slides = tensor_slide(tensor, slice_num)
if slides[0] is None:
return slides
for idx in range(1, len(slides)):
slides[idx] = regenerate_position_ids(slides[idx], idx)
return slides
def regenerate_position_ids(tensor, offset):
if tensor is None:
return None
tensor = tensor.clone()
for i in range(tensor.size(0)):
row = tensor[i]
zero_mask = (row == 0)  # two concatenated sequences: position ids restart at 0
if zero_mask.any():
first_zero_idx = torch.argmax(zero_mask.int()).item()
tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
else:
tensor[i] = tensor[i] - offset
return tensor
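# ---------------------------------------------------------------------------
# Worked example (illustrative only): a window shifted by offset=1 gets its
# position ids rebased so they still start at 0; rows that contain a 0 are
# treated as two concatenated sequences and rebuilt up to the first restart.
#
#     regenerate_position_ids(torch.tensor([[1, 2, 3, 4]]), offset=1)
#     # -> tensor([[0, 1, 2, 3]])
# ---------------------------------------------------------------------------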
def gpt_model_forward(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# generate inputs for main and mtps
input_ids, labels, position_ids, attention_mask = slice_inputs(
self,
input_ids,
labels,
position_ids,
attention_mask
)
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0])
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask[0],
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
loss = 0
# Multi token prediction module
if self.num_nextn_predict_layers and self.training:
if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
output_weight = self.output_layer.weight
output_weight.zero_out_wgrad = True
embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
mtp_hidden_states = hidden_states
for i in range(self.num_nextn_predict_layers):
mtp_hidden_states, mtp_loss = self.mtp_layers[i](
mtp_hidden_states, # [s,b,h]
input_ids[i + 1],
position_ids[i + 1] if position_ids[0] is not None else None,
attention_mask[i + 1] if attention_mask[0] is not None else None,
labels[i + 1] if labels[0] is not None else None,
inference_params,
packed_seq_params,
extra_block_kwargs,
embeding_weight=embedding_weight,
output_weight=output_weight,
)
loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss
if (
self.num_nextn_predict_layers
and getattr(self.decoder, "final_layernorm", None) is not None
):
# move block main model final norms here
hidden_states = self.decoder.final_layernorm(hidden_states)
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids[0],
'position_ids': position_ids[0],
'attention_mask': attention_mask[0],
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels[0] is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss += self.compute_language_model_loss(labels[0], logits)
return loss
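# ---------------------------------------------------------------------------
# Loss composition sketch (illustrative summary of the forward pass above,
# assuming N = num_nextn_predict_layers and training with labels):
#
#     total_loss = main_lm_loss
#                  + sum_i  (mtp_loss_scale / N) * mtp_loss_i,   i = 0..N-1
#
# where each mtp_loss_i is produced by mtp_layers[i] on inputs shifted by
# i + 1 tokens (see slice_inputs above).
# ---------------------------------------------------------------------------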
from typing import Callable
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import is_torch_min_version
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
VocabParallelEmbedding,
)
from megatron.core.tensor_parallel.mappings import (
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.utils import VocabUtility
def vocab_parallel_embedding_init(
self,
num_embeddings: int,
embedding_dim: int,
*,
init_method: Callable,
reduce_scatter_embeddings: bool = False,
config: ModelParallelConfig,
skip_weight_param_allocation: bool = False
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.reduce_scatter_embeddings = reduce_scatter_embeddings
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
(self.vocab_start_index, self.vocab_end_index) = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size,
)
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
self.deterministic_mode = config.deterministic_mode
# Allocate weights and initialize.
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
else:
self.weight = None
@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
"""Forward.
Args:
input_ (torch.Tensor): Input tensor.
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
if self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
import torch
from torch.utils.checkpoint import detach_variable  # used by CheckpointFunctionWithoutOutput.forward
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
_set_cuda_rng_state
)
class CheckpointFunctionWithoutOutput(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, checkpoint, *args):
with torch.no_grad():
outputs = run_function(*args)
# Store everything
ctx.save_for_backward(*detach_variable(args))
checkpoint.ctx = ctx
return outputs
@staticmethod
def backward(ctx, *args):
inputs = ctx.saved_tensors
outputs = ctx.outputs
torch.autograd.backward(outputs, args)
ctx.outputs = None
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
return (None, None) + grads
class CheckpointWithoutOutput:
def __init__(self):
self.run_function = None
self.fwd_cpu_rng_state = None
self.fwd_cuda_rng_state = None
self.fwd_cuda_rng_state_tracker = None
self.outputs = None
def checkpoint(self, run_function, distribute_saved_activations, *args):
self.run_function = run_function
if distribute_saved_activations:
raise RuntimeError(
"CheckpointFunctionWithoutOutput does not support "
"distribute_saved_activations"
)
# Copy the RNG states.
self.fwd_cpu_rng_state = torch.get_rng_state()
self.fwd_cuda_rng_state = torch.cuda.get_rng_state()
self.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
outputs = CheckpointFunctionWithoutOutput.apply(run_function, self, *args)
self.outputs = outputs
if isinstance(self.outputs, torch.Tensor):
self.outputs = (self.outputs,)
return outputs
def discard_output(self):
for output in self.outputs:
output.untyped_storage().resize_(0)
def recompute(self, _):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
# Store the current states.
cur_cpu_rng_state = torch.get_rng_state()
cur_cuda_rng_state = torch.cuda.get_rng_state()
cur_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what they were before the forward pass.
torch.set_rng_state(self.fwd_cpu_rng_state)
_set_cuda_rng_state(self.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(self.fwd_cuda_rng_state_tracker)
with torch.enable_grad():
outputs = self.run_function(*self.ctx.saved_tensors)
self.run_function = None
self.fwd_cpu_rng_state = None
self.fwd_cuda_rng_state = None
self.fwd_cuda_rng_state_tracker = None
# Set the states back to what they were at the start of this function.
torch.set_rng_state(cur_cpu_rng_state)
_set_cuda_rng_state(cur_cuda_rng_state)
get_cuda_rng_tracker().set_states(cur_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
for output, recomputation_output in zip(self.outputs, outputs):
output_size = recomputation_output.untyped_storage().size()
output.untyped_storage().resize_(output_size)
with torch.no_grad():
output.untyped_storage().copy_(recomputation_output.untyped_storage())
self.ctx.outputs = outputs
self.outputs = None
self.ctx = None
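# ---------------------------------------------------------------------------
# Usage sketch (illustrative): the pattern used by the MTP module below is to
# checkpoint a norm, free its output storage once it has been consumed, and
# re-materialize it during backward through a gradient hook. `norm_module`
# and `consumer` are placeholders, not names from this repository:
#
#     ckpt = CheckpointWithoutOutput()
#     normed = ckpt.checkpoint(norm_module, False, hidden_states)
#     out = consumer(normed)
#     ckpt.discard_output()              # release the activation storage
#     out.register_hook(ckpt.recompute)  # recompute it when grads flow back
# ---------------------------------------------------------------------------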
import warnings
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec
from .multi_token_predictor import (
MultiTokenPredicationSubmodules,
MultiTokenPredictor
)
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TENorm
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False):
"""
Multi-Token Prediction layer specification.
"""
use_te = use_te and HAVE_TE
mtp_spec = ModuleSpec(
module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules(
embedding=None,
enorm=TENorm if use_te else LNImpl,
hnorm=TENorm if use_te else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te else LNImpl,
output_layer=None,
)
)
return mtp_spec
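# ---------------------------------------------------------------------------
# Usage sketch (illustrative): the returned spec is intended for the patched
# GPTModel constructor's `mtp_spec` argument. `decoder_layer_spec` below is a
# placeholder for the ModuleSpec of the decoder layer used by the main model:
#
#     mtp_spec = get_mtp_spec(decoder_layer_spec, use_te=True)
#     model = GPTModel(config, transformer_layer_spec, vocab_size,
#                      max_sequence_length, mtp_spec=mtp_spec, ...)
# ---------------------------------------------------------------------------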
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
from ...tensor_parallel.random import CheckpointWithoutOutput
@dataclass
class MultiTokenPredicationSubmodules:
embedding: Union[ModuleSpec, type] = None
output_layer: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
final_layernorm: Union[ModuleSpec, type] = None
class MultiTokenPredictor(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredicationSubmodules,
vocab_size: int,
max_sequence_length: int,
layer_number: int = 1,
hidden_dropout: float = None,
pre_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
share_mtp_embedding_and_output_weight=True,
recompute_mtp_norm=False,
recompute_mtp_layer=False,
add_output_layer_bias=False
):
super().__init__(config=config)
self.config = config
self.submodules = submodules
self.layer_number = layer_number
self.hidden_dropout = hidden_dropout
self.hidden_size = self.config.hidden_size
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.position_embedding_type = position_embedding_type
# share with main model
self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
self.recompute_layer_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.add_output_layer_bias = add_output_layer_bias
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.eh_proj = build_module(
self.submodules.eh_proj,
self.hidden_size + self.hidden_size,
self.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='eh',
)
self.transformer_layer = build_module(
self.submodules.transformer_layer,
config=self.config,
)
if self.submodules.final_layernorm:
self.final_layernorm = build_module(
self.submodules.final_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=self.add_output_layer_bias,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def forward(
self,
hidden_input_ids: Tensor,
embed_input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
embeding_weight: Optional[torch.Tensor] = None,
output_weight: Optional[torch.Tensor] = None,
):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = self.embedding(
input_ids=embed_input_ids,
position_ids=position_ids,
weight=embeding_weight,
)
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
rotary_seq_len = decoder_input.size(0)
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
rotary_seq_len *= self.config.context_parallel_size
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
if self.recompute_layer_norm:
self.enorm_ckpt = CheckpointWithoutOutput()
enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
self.hnorm_ckpt = CheckpointWithoutOutput()
hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
else:
enorm_output = self.enorm(decoder_input)
hnorm_output = self.hnorm(hidden_input_ids)
# [s, b, h] -> [s, b, 2h]
hidden_states = torch.concat(
[hnorm_output,
enorm_output],
dim=-1
)
if self.recompute_layer_norm:
self.enorm_ckpt.discard_output()
self.hnorm_ckpt.discard_output()
hidden_states.register_hook(self.enorm_ckpt.recompute)
hidden_states.register_hook(self.hnorm_ckpt.recompute)
# hidden_states -> [s, b, h]
hidden_states, _ = self.eh_proj(hidden_states)
if self.config.tensor_model_parallel_size > 1:
hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
if self.config.sequence_parallel:
hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)
if self.recompute_mtp_layer:
hidden_states, context = tensor_parallel.checkpoint(
self.transformer_layer,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
None,
None,
rotary_pos_emb,
inference_params,
packed_seq_params,
)
else:
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
# Final layer norm.
if self.final_layernorm is not None:
if self.recompute_layer_norm:
self.finalnorm_ckpt = CheckpointWithoutOutput()
finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
else:
finalnorm_output = self.final_layernorm(hidden_states)
else:
finalnorm_output = hidden_states
logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
if self.recompute_layer_norm:
self.finalnorm_ckpt.discard_output()
logits.register_hook(self.finalnorm_ckpt.recompute)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return hidden_states, loss
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
from contextlib import nullcontext
from typing import Optional
from functools import wraps
import torch
from torch import Tensor
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import make_viewless_tensor
try:
from megatron.core.extensions.transformer_engine import TEDelayedScaling
HAVE_TE = True
except ImportError:
HAVE_TE = False
def transformer_block_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
# MTP requires separate layernorms for the main model and the MTP modules, so the final norm is moved out of the block
config = args[0] if len(args) > 0 else kwargs['config']
self.move_final_norm_out_of_block = getattr(config, "num_nextn_predict_layers", 0) > 0
return wrapper
def transformer_block_forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
attention_bias: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: Tensor = None,
):
"""
Perform the forward pass through the transformer block.
This method handles the core computation of the transformer, including
self-attention, optional cross-attention, and feed-forward operations.
Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is the
sequence length, b is the batch size, and h is the hidden size.
attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask for cross-attention context
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
Used as an alternative to apply attention mask for TE cuDNN attention.
inference_params (InferenceParams, optional): Parameters for inference-time
optimizations.
packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
processing.
Returns:
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Update the inference parameters with the current batch size in case it is variable
if inference_params and not self.training:
inference_params.current_batch_size = hidden_states.size(1)
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
import transformer_engine # To keep out TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(
with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red
)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
with rng_context, fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
else:
for l_no, layer in enumerate(self.layers):
with self.offload_context:
layer.use_cudagraph = True
if (len(self.cuda_graphs) == 0) or (not self.training):
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
else:
# CUDA graph replay for layer `l_no` and microbatch
# `self.current_microbatch`. TransformerEngine versions>=1.10
# allow keyword arguments with CUDA graph. However, CUDA graph
# accepts only Tensor inputs and Tensor outputs. Hence,
# `inference_params` and `packed_seq_params` are excluded from
# input list while output is limited to `hidden_states`.
cg_index = self.current_microbatch % len(self.cuda_graphs[l_no])
assert not any(
[inference_params, packed_seq_params]
), "CUDA graph accepts only Tensor inputs."
optional_inputs = self.get_cuda_graph_optional_args(
attention_mask,
context,
context_mask,
rotary_pos_emb,
attention_bias,
inference_params,
packed_seq_params,
)
hidden_states = self.cuda_graphs[l_no][cg_index](
hidden_states, **optional_inputs
)
if (
torch.is_grad_enabled()
and self.config.cpu_offloading
and self.group_prefetch_offload_commit_async is not None
):
hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
# Final layer norm.
if self.final_layernorm is not None:
hidden_states = self.final_layernorm(hidden_states)
# TENorm produces a "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True
)
return hidden_states
import torch
from typing import List, Optional, Union
def tensor_slide(
tensor: Optional[torch.Tensor],
num_slice: int,
dims: Union[int, List[int]] = -1,
step: int = 1,
return_first=False,
) -> List[Union[torch.Tensor, None]]:
"""通用滑动窗口函数,支持任意维度"""
if tensor is None:
# return `List[None]` to avoid NoneType Error
return [None] * (num_slice + 1)
if num_slice == 0:
return [tensor]
window_size = tensor.shape[-1] - num_slice
dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
# contiguous sliding over multiple dimensions
slices = []
for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
slice_obj = [slice(None)] * tensor.dim()
for dim in dims:
slice_obj[dim] = slice(i, i + window_size)
slices.append(tensor[tuple(slice_obj)])
if return_first:
# return only the leading window
return slices[:1]
return slices
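# ---------------------------------------------------------------------------
# Hedged demo (not part of the original module): with num_slice=2 the window
# width is seq_len - 2, giving num_slice + 1 overlapping windows along the
# last dimension.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ids = torch.arange(8).unsqueeze(0)        # shape [1, 8]
    windows = tensor_slide(ids, num_slice=2)  # 3 tensors of shape [1, 6]
    assert len(windows) == 3
    assert all(w.shape == (1, 6) for w in windows)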
import argparse
from megatron.training.arguments import (
_add_network_size_args,
_add_regularization_args,
_add_training_args,
_add_initialization_args,
_add_learning_rate_args,
_add_checkpointing_args,
_add_mixed_precision_args,
_add_distributed_args,
_add_validation_args,
_add_data_args,
_add_autoresume_args,
_add_biencoder_args,
_add_vision_args,
_add_moe_args,
_add_mla_args,
_add_logging_args,
_add_straggler_detector_args,
_add_inference_args,
_add_transformer_engine_args,
_add_retro_args,
_add_experimental_args,
_add_one_logger_args,
_add_ft_package_args,
_add_config_logger_args,
_add_rerun_machine_args,
)
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
# Standard arguments.
parser = _add_network_size_args(parser)
parser = _add_regularization_args(parser)
parser = _add_training_args(parser)
parser = _add_initialization_args(parser)
parser = _add_learning_rate_args(parser)
parser = _add_checkpointing_args(parser)
parser = _add_mixed_precision_args(parser)
parser = _add_distributed_args(parser)
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_tokenizer_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vision_args(parser)
parser = _add_moe_args(parser)
parser = _add_mla_args(parser)
parser = _add_mtp_args(parser)
parser = _add_logging_args(parser)
parser = _add_straggler_detector_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
parser = _add_experimental_args(parser)
parser = _add_one_logger_args(parser)
parser = _add_ft_package_args(parser)
parser = _add_config_logger_args(parser)
parser = _add_rerun_machine_args(parser)
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
# Parse.
if ignore_unknown_args:
args, _ = parser.parse_known_args()
else:
args = parser.parse_args()
# Experimental yaml
if args.yaml_cfg is not None:
from megatron.training.yaml_arguments import load_yaml
assert args.yaml_cfg and not args.use_legacy_models, \
"Yaml config is not supported with legacy models."
args = load_yaml(args.yaml_cfg)
# Args from environment
#args.rank = int(os.getenv('RANK', '0'))
#args.world_size = int(os.getenv("WORLD_SIZE", '1'))
return args
def _add_tokenizer_args(parser):
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--vocab-size', type=int, default=None,
help='Size of vocab before EOD or padding.')
group.add_argument('--extra-vocab-size', type=int, default=0,
                   help='Number of extra entries added on top of the base tokenizer vocabulary.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer',
'SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'Llama2Tokenizer',
'TikTokenizer',
'MultimodalTokenizer',
'NullTokenizer',
'DeepSeekV2Tokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='Sentencepiece tokenizer model.')
group.add_argument('--tiktoken-pattern', type=str, default=None,
help='Which tiktoken pattern to use. Options: [v1, v2]')
group.add_argument('--tiktoken-num-special-tokens', type=int, default=1000,
help='Number of special tokens in tiktoken tokenizer')
group.add_argument('--tiktoken-special-tokens', type=str, nargs='+', default=None,
help='List of tiktoken special tokens, needs to have ["<unk>", "<s>", "</s>"]')
return parser
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0,
                   help='Number of Multi-Token Prediction (MTP) layers.')
group.add_argument('--mtp-loss-scale', type=float, default=0.3,
                   help='Scaling factor applied to the Multi-Token Prediction loss.')
group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
                   help='Recompute the normalization inside the Multi-Token Prediction layer.')
group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
                   help='Recompute the full Multi-Token Prediction layer.')
group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                   help='Share the main model embedding and output weight with the MTP layer.')
return parser
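# A minimal sketch (not part of the original file) showing how the tokenizer and
# MTP argument groups compose on a bare parser; the sample command line below is
# hypothetical. Guarded by __main__ so it never runs on import.
if __name__ == "__main__":
    demo_parser = argparse.ArgumentParser(allow_abbrev=False)
    demo_parser = _add_tokenizer_args(demo_parser)
    demo_parser = _add_mtp_args(demo_parser)
    demo_args = demo_parser.parse_args([
        '--tokenizer-type', 'DeepSeekV2Tokenizer',
        '--num-nextn-predict-layers', '1',
        '--mtp-loss-scale', '0.3',
        '--share-mtp-embedding-and-output-weight',
    ])
    assert demo_args.num_nextn_predict_layers == 1
    assert demo_args.mtp_loss_scale == 0.3
    assert demo_args.share_mtp_embedding_and_output_weight is True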
from .tokenizer import build_tokenizer
from transformers import AutoTokenizer

from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from megatron.training.tokenizer.tokenizer import (
    _BertWordPieceTokenizer,
    _GPT2BPETokenizer,
    _SentencePieceTokenizer,
    _GPTSentencePieceTokenizer,
    _HuggingFaceTokenizer,
    _Llama2Tokenizer,
    CustomTikTokenizer,
    _NullTokenizer,
    _vocab_size_with_padding,
    # Regex patterns used by the 'TikTokenizer' branch below (defined alongside
    # CustomTikTokenizer in recent Megatron-LM releases).
    PATTERN_TIKTOKEN,
    PATTERN_TIKTOKEN_V2,
)
def build_tokenizer(args, **kwargs):
"""Initialize tokenizer."""
if args.rank == 0:
print('> building {} tokenizer ...'.format(args.tokenizer_type), flush=True)
# Select and instantiate the tokenizer.
if args.tokenizer_type == 'BertWordPieceLowerCase':
assert args.vocab_file is not None
tokenizer = _BertWordPieceTokenizer(
vocab_file=args.vocab_file, lower_case=True, vocab_extra_ids=args.vocab_extra_ids
)
elif args.tokenizer_type == 'BertWordPieceCase':
assert args.vocab_file is not None
tokenizer = _BertWordPieceTokenizer(
vocab_file=args.vocab_file, lower_case=False, vocab_extra_ids=args.vocab_extra_ids
)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.vocab_file is not None
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == 'SentencePieceTokenizer':
assert args.tokenizer_model is not None
tokenizer = _SentencePieceTokenizer(
args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids
)
elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
assert args.tokenizer_model is not None
tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'HuggingFaceTokenizer':
tokenizer = _HuggingFaceTokenizer(args.tokenizer_model, **kwargs)
elif args.tokenizer_type == 'Llama2Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama2Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'TikTokenizer':
assert args.tokenizer_model is not None
assert args.tiktoken_pattern is not None
assert args.tiktoken_pattern in {"v1", "v2"}
pattern = PATTERN_TIKTOKEN if args.tiktoken_pattern == "v1" else PATTERN_TIKTOKEN_V2
tokenizer = CustomTikTokenizer(
path=args.tokenizer_model,
pattern=pattern,
vocab_size=args.vocab_size,
num_special_tokens=args.tiktoken_num_special_tokens,
special_tokens=args.tiktoken_special_tokens,
)
elif args.tokenizer_type == 'NullTokenizer':
assert args.vocab_size is not None
tokenizer = _NullTokenizer(args.vocab_size)
elif args.tokenizer_type == "MultimodalTokenizer":
try:
import transformers
except ImportError:
raise ImportError(
"MultimodalTokenizer currently requires transformers library to be installed"
)
# MultimodalTokenizer is not imported at module level; pull it in lazily here so
# the dependency is only needed for this tokenizer type (in recent Megatron-LM
# releases it lives in megatron.training.tokenizer.multimodal_tokenizer).
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
kwargs = dict()
if args.tokenizer_prompt_format == "nvlm-yi-34b":
kwargs = {
"from_slow": True,
"legacy": False,
"add_bos_token": True,
}
# Currently, only HuggingFace tokenizers are supported.
underlying_tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=args.tokenizer_model, **kwargs
)
tokenizer = MultimodalTokenizer(
underlying_tokenizer,
args.tokenizer_prompt_format,
args.special_tokens,
args.image_tag_type,
)
elif args.tokenizer_type == "DeepSeekV2Tokenizer":
tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
args.padded_vocab_size = tokenizer.vocab_size
else:
raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type))
# Add vocab size (if not already set from a checkpoint).
if getattr(args, "padded_vocab_size", None) is None:
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
return tokenizer
class _DeepSeekV2Tokenizer(MegatronTokenizer):
def __init__(self, tokenizer_path, extra_vocab_size):
super().__init__(tokenizer_path)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
padding_side="right",
trust_remote_code=True
)
self.extra_vocab_size = extra_vocab_size
if self.tokenizer.chat_template is None:
self.tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
try:
test_conversation = [
{'role': 'user', 'content': 'hello world'}
]
self.apply_chat_template(test_conversation)
except Exception:
# the default chat_template is invalid, assume user will not do SFT
self.tokenizer.chat_template = None
def __call__(self, text, return_tensors=None,
padding=None, max_length=None, truncation=None, add_special_tokens=None):
return self.tokenizer(text, return_tensors=return_tensors, padding=padding,
max_length=max_length, truncation=truncation, add_special_tokens=add_special_tokens)
def apply_chat_template(self, conversations, tokenize:bool=True, **kwargs):
return self.tokenizer.apply_chat_template(conversations, tokenize=tokenize, **kwargs)
@property
def vocab_size(self):
return len(self.tokenizer) + self.extra_vocab_size - 2
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.tokenizer.eos_token_id
@property
def eos_token(self):
return self.tokenizer.eos_token
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id
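# A minimal usage sketch (illustrative only; the tokenizer path is hypothetical).
# It exercises exactly the fields build_tokenizer() reads for the DeepSeekV2
# branch and is guarded by __main__ so it never runs on import.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        rank=0,
        tokenizer_type='DeepSeekV2Tokenizer',
        tokenizer_model='/path/to/deepseek-v2-tokenizer',  # hypothetical local path
        extra_vocab_size=0,
    )
    tok = build_tokenizer(demo_args)
    ids = tok.tokenize("hello world")
    print(tok.detokenize(ids), demo_args.padded_vocab_size)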
import torch
from megatron.core import mpu
from megatron.training import get_args
def get_batch_on_this_tp_rank(data_iterator):
args = get_args()
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if mpu.get_tensor_model_parallel_rank() == 0:
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
}
if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(batch['position_ids'])
else:
tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.float32,
device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
device = torch.cuda.current_device()
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers:
_broadcast(tokens)
else:
tokens = None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(position_ids)
else:
position_ids = None
batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
return batch
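# Shape contract of the batch returned above (illustrative, derived from the
# torch.empty allocations): every per-token tensor carries
# `num_nextn_predict_layers` extra positions so the MTP heads can shift labels.
# For example, with micro_batch_size=2, seq_length=4096, num_nextn_predict_layers=1:
#   tokens / labels / position_ids : [2, 4097]            (int64)
#   loss_mask                      : [2, 4097]            (float32)
#   attention_mask                 : [2, 1, 4097, 4097]   (bool) or None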
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import os
import torch
from functools import partial
from contextlib import nullcontext
import inspect
from typing import List, Optional, Tuple, Union
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.rerun_state_machine import get_rerun_state_machine
import megatron.legacy.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector
from megatron.core.transformer.spec_utils import import_module
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
get_blend_and_blend_per_split,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron import megatron_adaptor
stimer = StragglerDetector()
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model.
If use_legacy_models is set to True, this returns the legacy GPT model; otherwise it returns the mcore GPT model.
Args:
    pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
    post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
Returns:
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
use_te = args.transformer_impl == "transformer_engine"
if args.record_memory_history:
torch.cuda.memory._record_memory_history(True,
# keep 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000,
# record stack information for the trace events
trace_alloc_record_context=True)
def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened
print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)
print_rank_0(f"config: {config}")
if args.use_legacy_models:
model = megatron.legacy.model.GPTModel(
config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
)
else: # using core models
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
if args.num_experts:
# Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
else:
# Define the decoder layer spec
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
build_model_context = nullcontext
build_model_context_args = {}
if args.fp8_param_gather:
try:
from transformer_engine.pytorch import fp8_model_init
build_model_context = fp8_model_init
build_model_context_args["enabled"] = True
# Check if fp8_model_init supports preserve_high_precision_init_val
if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
build_model_context_args["preserve_high_precision_init_val"] = True
except ImportError:
raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
# Define the mtp layer spec
if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
else:
mtp_transformer_layer_spec = transformer_layer_spec
mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
with build_model_context(**build_model_context_args):
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
mtp_spec=mtp_spec
)
# model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model)
return model
def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank(data_iterator)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)
return batch.values()
# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""Loss function.
Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses
Returns:
the loss scalar for this micro-batch
the number of non-padded tokens in this microbatch
a dict containing reporting metrics on the loss and number of tokens across
the data parallel ranks
"""
args = get_args()
losses = output_tensor.float()
if args.num_nextn_predict_layers > 0:
loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
# Check individual rank losses are not NaN prior to DP all-reduce.
rerun_state_machine = get_rerun_state_machine()
if args.check_for_nan_in_loss_and_grad:
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isnan,
message="found NaN in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
context="loss",
),
message="Spiky loss",
tolerance=0.0, # forward pass calculations are deterministic
fatal=False,
)
# Reduce loss for logging.
reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
local_num_tokens = loss[1].clone().detach().to(torch.int)
return (
loss[0] * args.context_parallel_size,
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
def forward_step(data_iterator, model: GPTModel):
"""Forward training step.
Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
"""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
with stimer:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def is_dataset_built_on_rank():
return (
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):
tokenizer = get_tokenizer()
# Sometimes --data-path is too long, instead we parse it from a file.
blend: Optional[Tuple[List[str], Optional[List[float]]]]
blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
blend, blend_per_split = get_blend_and_blend_per_split(args)
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length + args.num_nextn_predict_layers,
blend=blend,
blend_per_split=blend_per_split,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
mmap_bin_files=args.mmap_bin_files,
tokenizer=tokenizer,
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
s3_cache_path=args.s3_cache_path,
)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build the train test and validation datasets.
Args:
train_val_test_num_samples : A list containing the number of samples in train test and validation.
"""
args = get_args()
config = core_gpt_dataset_config_from_args(args)
if args.mock_data:
dataset_type = MockGPTDataset
else:
dataset_type = GPTDataset
print_rank_0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
).build()
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
# Temporary for transition to core datasets
train_valid_test_datasets_provider.is_distributed = True
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
)
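# Example launch (hypothetical paths and flag values; the remaining standard
# Megatron-LM model/data/training arguments are elided):
#
#   torchrun --nproc_per_node=8 pretrain_gpt.py \
#       --tokenizer-type DeepSeekV2Tokenizer \
#       --tokenizer-model /path/to/deepseek-v2-tokenizer \
#       --num-nextn-predict-layers 1 \
#       --mtp-loss-scale 0.3 \
#       ...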