Commit 1f7b14ab authored by sdwldchl's avatar sdwldchl

rewrite mtp

parent 89d29a02
@@ -5,6 +5,8 @@ import types
 import argparse
 import torch

+from .adaptor_arguments import get_adaptor_args
+

 class MegatronAdaptation:
     """
@@ -21,6 +23,15 @@ class MegatronAdaptation:
         for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
             adaptation.execute()
         MegatronAdaptation.apply()
+
+        # from .patch_utils import MegatronPatchesManager
+        # args = get_adaptor_args()
+        # for feature in FEATURES_LIST:
+        #     if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
+        #         feature.register_patches(MegatronPatchesManager, args)
+        # MindSpeedPatchesManager.apply_patches()
+
         # MegatronAdaptation.post_execute()

     @classmethod
@@ -87,38 +98,20 @@ class CoreAdaptation(MegatronAdaptationABC):
         self.patch_miscellaneous()

     def patch_core_distributed(self):
-        # Mtp share embedding
+        # mtp share embedding
         from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
         MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
                                     _allreduce_word_embedding_grads)

     def patch_core_models(self):
-        from ..core.models.common.embeddings.language_model_embedding import (
-            language_model_embedding_forward,
-            language_model_embedding_init_func
-        )
-        from ..core.models.gpt.gpt_model import (
-            gpt_model_forward,
-            gpt_model_init,
-            shared_embedding_or_output_weight,
-        )
         from ..core.models.common.language_module.language_module import (
             setup_embeddings_and_output_layer,
-            tie_embeddings_and_output_weights_state_dict
+            tie_embeddings_and_output_weights_state_dict,
         )
+        from ..core.models.gpt.gpt_model import GPTModel
         from ..training.utils import get_batch_on_this_tp_rank

-        # Embedding
-        MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
-            language_model_embedding_init_func)
-        MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
-            language_model_embedding_forward)
-        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
-
-        # GPT Model
+        # LanguageModule
         MegatronAdaptation.register(
             'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
             setup_embeddings_and_output_layer)
@@ -126,17 +119,16 @@ class CoreAdaptation(MegatronAdaptationABC):
             'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
             tie_embeddings_and_output_weights_state_dict)
-        MegatronAdaptation.register(
-            'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
-            shared_embedding_or_output_weight)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
+        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
+
+        # GPT Model
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
-        # Transformer block
+        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                                     transformer_block_init_wrapper)
@@ -174,13 +166,10 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
-        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
-                                    vocab_parallel_embedding_forward)
-        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-                                    vocab_parallel_embedding_init_wrapper,
+                                    torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)

         # VocabParallelCrossEntropy
@@ -211,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)

+    def patch_pipeline_parallel(self):
+        from ..core.pipeline_parallel.schedules import forward_step_wrapper
+
+        # pipeline_parallel.schedules.forward_step
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
+                                    forward_step_wrapper,
+                                    apply_wrapper=True)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
@@ -255,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
                                     parallel_mlp_init_wrapper,
                                     apply_wrapper=True)

+        # ParallelAttention
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
                                     parallel_attention_init_wrapper,
                                     apply_wrapper=True)
...
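For readers unfamiliar with this adaptor pattern: each registration above maps a dotted Megatron path to a replacement callable or class. Below is a minimal, purely illustrative sketch of how such a dotted-path patch can be resolved and applied; apply_dotted_patch is a hypothetical helper, not this repository's actual MegatronAdaptation machinery.

# Illustrative only: a generic dotted-path monkey patch, assuming the manager
# ultimately imports the longest module prefix and setattr()s the final attribute.
import importlib

def apply_dotted_patch(dotted_path, replacement):
    parts = dotted_path.split('.')
    # Import the longest importable module prefix.
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
            remaining = parts[i:]
            break
        except ImportError:
            continue
    else:
        raise ImportError(f"cannot import any prefix of {dotted_path}")
    # Walk down to the parent object (e.g. a class), then overwrite the attribute.
    for name in remaining[:-1]:
        obj = getattr(obj, name)
    setattr(obj, remaining[-1], replacement)

# e.g. apply_dotted_patch('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)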
@@ -148,11 +148,29 @@ class MegatronPatchesManager:
     patches_info = {}

     @staticmethod
-    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
+    def register_patch(
+            orig_func_or_cls_name,
+            new_func_or_cls=None,
+            force_patch=False,
+            create_dummy=False,
+            apply_wrapper=False,
+            remove_origin_wrappers=False
+    ):
         if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
-            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
+            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
+                orig_func_or_cls_name,
+                new_func_or_cls,
+                create_dummy,
+                apply_wrapper=apply_wrapper,
+                remove_origin_wrappers=remove_origin_wrappers
+            )
         else:
-            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
+            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
+                new_func_or_cls,
+                force_patch,
+                apply_wrapper=apply_wrapper,
+                remove_origin_wrappers=remove_origin_wrappers
+            )

     @staticmethod
     def apply_patches():
...
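A hedged usage sketch for the extended register_patch signature above. The wrapper semantics (apply_wrapper=True wrapping the original callable instead of replacing it) are inferred from how the adaptor passes decorator-style patches such as forward_step_wrapper; my_forward_timer is purely illustrative, and MegatronPatchesManager is assumed importable from the repo's patch_utils module.

from functools import wraps
import time

def my_forward_timer(fn):                      # hypothetical wrapper-style patch
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.time()
        out = fn(*args, **kwargs)
        print(f"forward took {time.time() - start:.3f}s")
        return out
    return wrapper

MegatronPatchesManager.register_patch(
    'megatron.core.models.gpt.gpt_model.GPTModel.forward',
    my_forward_timer,
    apply_wrapper=True,            # treat the patch as a wrapper around the original
    remove_origin_wrappers=False,  # keep any wrappers the target already carries
)
MegatronPatchesManager.apply_patches()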
@@ -28,7 +28,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
            model_module = model[0]

        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-       if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
+       # If share_embeddings_and_output_weights is True, we need to maintain duplicated
+       # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
+       # we also need to maintain duplicated embedding weights in mtp process stage.
+       # So we need to allreduce grads of embedding in the embedding group in these cases.
+       if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
            weight = model_module.shared_embedding_or_output_weight()
            grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
            orig_grad = getattr(weight, grad_attr)
...
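A minimal sketch of the gradient sync the new comment describes (assumptions: a process group equivalent to Megatron's embedding group is passed in, and the helper name is hypothetical): every rank that holds a copy of the shared embedding all-reduces its gradient so all copies apply the same update.

import torch
import torch.distributed as dist

def sync_shared_embedding_grad(weight, embedding_group):
    # Megatron keeps the accumulated gradient either in .main_grad or .grad.
    grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
    grad = getattr(weight, grad_attr)
    if grad is not None:
        dist.all_reduce(grad, group=embedding_group)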
@@ -4,6 +4,7 @@ import torch
 from megatron.core import parallel_state
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.models.common.language_module.language_module import LanguageModule
 from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
@@ -27,7 +28,7 @@ def setup_embeddings_and_output_layer(self) -> None:
     # So we need to copy embedding weights from pre processing stage as initial parameters
     # in these cases.
     if not self.share_embeddings_and_output_weights and not getattr(
-        self.config, 'num_nextn_predict_layers', 0
+        self.config, 'mtp_num_layers', 0
     ):
         return
@@ -41,10 +42,10 @@ def setup_embeddings_and_output_layer(self) -> None:
     if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
         self.shared_embedding_or_output_weight().shared_embedding = True

-    if self.post_process and not self.pre_process:
+    if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
         assert not parallel_state.is_pipeline_first_stage()
-        # set word_embeddings weights to 0 here, then copy first
-        # stage's weights using all_reduce below.
+        # set weights of the duplicated embedding to 0 here,
+        # then copy weights from pre processing stage using all_reduce below.
         weight = self.shared_embedding_or_output_weight()
         weight.data.fill_(0)
         weight.shared = True
@@ -114,7 +115,7 @@ def tie_embeddings_and_output_weights_state_dict(
     # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
     # the shared weights will be stored in embedding layer, and output layer will not have
     # any weight.
-    if self.post_process and getattr(self, 'num_nextn_predict_layers', False):
+    if getattr(self, 'mtp_process', False):
         # No output layer
         assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
         return
...
 import warnings
-from typing import Optional
+from typing import Optional, Union

 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
+from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
-from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:
@@ -36,6 +36,55 @@ try:
 except ImportError:
     warnings.warn('Apex is not installed.')

+from dcu_megatron.core.tensor_parallel.layers import (
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear
+)
+from dcu_megatron.core.transformer.multi_token_prediction import (
+    MultiTokenPredictionBlockSubmodules,
+    get_mtp_layer_offset,
+    get_mtp_layer_spec,
+    get_mtp_num_layers_to_build,
+)
+
+
+def get_gpt_mtp_block_spec(
+    config: TransformerConfig,
+    spec: Union[TransformerBlockSubmodules, ModuleSpec],
+    use_transformer_engine: bool,
+) -> MultiTokenPredictionBlockSubmodules:
+    """GPT Multi-Token Prediction (MTP) block spec."""
+    num_layers_to_build = get_mtp_num_layers_to_build(config)
+    if num_layers_to_build == 0:
+        return None
+
+    if isinstance(spec, TransformerBlockSubmodules):
+        # get the spec for the last layer of decoder block
+        transformer_layer_spec = spec.layer_specs[-1]
+    elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
+        transformer_layer_spec = spec
+    else:
+        raise ValueError(f"Invalid spec: {spec}")
+
+    mtp_layer_spec = get_mtp_layer_spec(
+        transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
+    )
+    mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
+    mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
+
+    offset = get_mtp_layer_offset(config)
+    # split the mtp layer specs to only include the layers that are built in this pipeline stage.
+    mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
+    if len(mtp_layer_specs) > 0:
+        assert (
+            len(mtp_layer_specs) == config.mtp_num_layers
+        ), f"currently all of the mtp layers must stage in the same pipeline stage."
+        mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
+    else:
+        mtp_block_spec = None
+
+    return mtp_block_spec
+
+
 def get_gpt_layer_with_flux_spec(
     num_experts: Optional[int] = None,
...
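A hedged usage sketch of get_gpt_mtp_block_spec as added above, mirroring how pretrain_gpt.py (further down in this commit) builds the spec only when config.mtp_num_layers is set; the build_gpt_specs helper itself is illustrative, not repo code.

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec

def build_gpt_specs(config, use_te=True):
    layer_spec = get_gpt_layer_with_transformer_engine_spec()
    mtp_block_spec = None
    if getattr(config, "mtp_num_layers", None):
        # Returns None when no MTP layers are built on this pipeline stage.
        mtp_block_spec = get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=use_te)
    return layer_spec, mtp_block_spec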
This diff is collapsed.
import torch
from functools import wraps
from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
def forward_step_wrapper(fn):
@wraps(fn)
def wrapper(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs,
):
output, num_tokens = fn(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs
)
if not isinstance(input_tensor, list):
# unwrap_output_tensor True
output_tensor = output
else:
output_tensor = output[0]
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output, num_tokens
return wrapper
\ No newline at end of file
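The scale selection in the wrapper above reduces to simple arithmetic; here is a small self-contained illustration of that rule (a standalone function, not repo code): with per-token loss the MTP loss scale is the grad scale itself, otherwise it is divided by the number of microbatches.

import torch

def mtp_loss_scale(grad_scale_func, num_microbatches, calculate_per_token_loss):
    # Same rule as the wrapper: use the grad scaler's scale if there is one,
    # otherwise 1.0; divide by num_microbatches unless loss is per-token.
    loss_scale = (
        grad_scale_func(torch.ones(1))
        if grad_scale_func is not None
        else torch.ones(1)
    )
    return loss_scale if calculate_per_token_loss else loss_scale / num_microbatches

print(mtp_loss_scale(None, 4, False))  # tensor([0.2500])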
 from .layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear,
-    vocab_parallel_embedding_forward,
-    vocab_parallel_embedding_init_wrapper,
 )
\ No newline at end of file
 import os
-import copy
 import socket
 import warnings
-from functools import wraps
 from typing import Callable, List, Optional

-if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-    try:
-        import flux
-        from dcu_megatron.core.utils import is_flux_min_version
-    except ImportError:
-        raise ImportError("flux is NOT installed")
+try:
+    import flux
+except ImportError:
+    raise ImportError("flux is NOT installed")

 import torch
-import torch.nn.functional as F
-from torch.nn.parameter import Parameter

-from megatron.training import print_rank_0
 from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.parallel_state import (
     get_global_memory_buffer,
     get_tensor_model_parallel_group,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from megatron.core.utils import (
-    is_torch_min_version,
-    prepare_input_tensors_for_wgrad_compute
-)
-from megatron.core.tensor_parallel.layers import (
-    _initialize_affine_weight_cpu,
-    _initialize_affine_weight_gpu,
-    VocabParallelEmbedding,
-)
+from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
 from megatron.core.tensor_parallel.mappings import (
+    _reduce,
     copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
-    reduce_scatter_to_sequence_parallel_region,
-    _reduce_scatter_along_first_dim,
-    _gather_along_first_dim,
 )
-from megatron.core.tensor_parallel.utils import VocabUtility
-from megatron.core.tensor_parallel.mappings import _reduce
 from megatron.core.tensor_parallel import (
     ColumnParallelLinear,
     RowParallelLinear,
@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,
     dist_all_gather_func,
-    linear_with_frozen_weight,
-    linear_with_grad_accumulation_and_async_allreduce
 )
+from dcu_megatron.core.utils import is_flux_min_version

 _grad_accum_fusion_available = True
 try:
@@ -61,74 +41,6 @@ except ImportError:
     _grad_accum_fusion_available = False

-def vocab_parallel_embedding_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self,
-                *args,
-                skip_weight_param_allocation: bool = False,
-                **kwargs
-                ):
-        if (
-            skip_weight_param_allocation
-            and "config" in kwargs
-            and hasattr(kwargs["config"], "perform_initialization")
-        ):
-            config = copy.deepcopy(kwargs["config"])
-            config.perform_initialization = False
-            kwargs["config"] = config
-
-        fn(self, *args, **kwargs)
-
-        if skip_weight_param_allocation:
-            self.weight = None
-
-    return wrapper
-
-
-@torch.compile(mode='max-autotune-no-cudagraphs')
-def vocab_parallel_embedding_forward(self, input_, weight=None):
-    """Forward.
-
-    Args:
-        input_ (torch.Tensor): Input tensor.
-    """
-    if weight is None:
-        if self.weight is None:
-            raise RuntimeError(
-                "weight was not supplied to VocabParallelEmbedding forward pass "
-                "and skip_weight_param_allocation is True."
-            )
-        weight = self.weight
-
-    if self.tensor_model_parallel_size > 1:
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
-    else:
-        masked_input = input_
-    # Get the embeddings.
-    if self.deterministic_mode:
-        output_parallel = weight[masked_input]
-    else:
-        # F.embedding currently has a non-deterministic backward function
-        output_parallel = F.embedding(masked_input, weight)
-    # Mask the output embedding.
-    if self.tensor_model_parallel_size > 1:
-        output_parallel[input_mask, :] = 0.0
-
-    if self.reduce_scatter_embeddings:
-        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-        output_parallel = output_parallel.transpose(0, 1).contiguous()
-        output = reduce_scatter_to_sequence_parallel_region(output_parallel)
-    else:
-        # Reduce across all the model parallel GPUs.
-        output = reduce_from_tensor_model_parallel_region(output_parallel)
-    return output
-
-
 def get_tensor_model_parallel_node_size(group=None):
     """ 获取节点数
     """
...
This diff is collapsed.
@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
         # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
         config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "num_nextn_predict_layers", 0) > 0:
+        if getattr(config, "mtp_num_layers", 0) > 0:
             self.main_final_layernorm = self.final_layernorm
             self.final_layernorm = None
...
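For context, a hedged sketch of the complete wrapper this hunk touches: only the config lookup and the two layernorm assignments are taken from the diff; the decorator frame around them is inferred, not quoted from the repo.

from functools import wraps

def transformer_block_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        # mtp requires separate layernorms for the main model and the mtp modules,
        # so the final norm is moved out of the block when MTP layers are configured.
        config = args[0] if len(args) > 1 else kwargs['config']
        if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None
    return wrapper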
+from typing import Optional
+from functools import wraps
 from dataclasses import dataclass

+from megatron.training import get_args
 from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig


+def transformer_config_post_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self):
+        fn(self)
+
+        args = get_args()
+
+        """Number of Multi-Token Prediction (MTP) Layers."""
+        self.mtp_num_layers = args.mtp_num_layers
+
+        """Weighting factor of Multi-Token Prediction (MTP) loss."""
+        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
+
+        ##################
+        # flux
+        ##################
+        self.flux_transpose_weight = args.flux_transpose_weight
+
+    return wrapper
+
+
 @dataclass
 class ExtraTransformerConfig:
     ##################
     # multi-token prediction
     ##################
-    num_nextn_predict_layers: int = 0
-    """The number of multi-token prediction layers"""
-    mtp_loss_scale: float = 0.3
-    """Multi-token prediction loss scale"""
-    recompute_mtp_norm: bool = False
-    """Whether to recompute mtp normalization"""
-    recompute_mtp_layer: bool = False
-    """Whether to recompute mtp layer"""
-    share_mtp_embedding_and_output_weight: bool = False
-    """share embedding and output weight with mtp layer."""
+    mtp_num_layers: Optional[int] = None
+    """Number of Multi-Token Prediction (MTP) Layers."""
+
+    mtp_loss_scaling_factor: Optional[float] = None
+    """Weighting factor of Multi-Token Prediction (MTP) loss."""

     ##################
     # flux
...
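An illustrative sketch (stand-in class names, not the repo's TransformerConfigPatch) of the mixin pattern above: the extra MTP fields live in a separate dataclass that a patched config class inherits next to the upstream config.

from dataclasses import dataclass
from typing import Optional

@dataclass
class UpstreamConfigSketch:                 # stand-in for megatron's TransformerConfig
    num_layers: int = 24
    hidden_size: int = 1024

@dataclass
class ExtraTransformerConfigSketch:         # mirrors the fields added above
    mtp_num_layers: Optional[int] = None
    mtp_loss_scaling_factor: Optional[float] = None

@dataclass
class PatchedConfigSketch(ExtraTransformerConfigSketch, UpstreamConfigSketch):
    pass

cfg = PatchedConfigSketch(num_layers=32, mtp_num_layers=1, mtp_loss_scaling_factor=0.3)
assert cfg.mtp_num_layers == 1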
@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
 def _add_mtp_args(parser):
     group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
-    group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale')
-    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
-                       help='Multi-Token prediction recompute norm')
-    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
-                       help='Multi-Token prediction recompute layer')
-    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                       help='Main model share embedding and output weight with mtp layer.')
+    group.add_argument('--mtp-num-layers', type=int, default=None,
+                       help='Number of Multi-Token Prediction (MTP) Layers.'
+                       'MTP extends the prediction scope to multiple future tokens at each position.'
+                       'This MTP implementation sequentially predict additional tokens '
+                       'by using D sequential modules to predict D additional tokens.')
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
+                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
+                       'We compute the average of the MTP losses across all depths, '
+                       'and multiply it the scaling factor to obtain the overall MTP loss, '
+                       'which serves as an additional training objective.')
     return parser
...
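A quick check of what the rewritten argument group exposes (hedged: assumes _add_mtp_args from the hunk above is in scope; attribute names follow argparse's dash-to-underscore convention):

import argparse

parser = argparse.ArgumentParser()
parser = _add_mtp_args(parser)                       # function from the hunk above
args = parser.parse_args(['--mtp-num-layers', '1'])

assert args.mtp_num_layers == 1                      # default is None when the flag is omitted
assert abs(args.mtp_loss_scaling_factor - 0.3) < 1e-9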
@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):
     args = get_args()

     def _broadcast(item):
         if item is not None:
             torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

     if mpu.get_tensor_model_parallel_rank() == 0:

         if data_iterator is not None:
             data = next(data_iterator)
         else:
             data = None

         batch = {
             'tokens': data["tokens"].cuda(non_blocking = True),
             'labels': data["labels"].cuda(non_blocking = True),
             'loss_mask': data["loss_mask"].cuda(non_blocking = True),
             'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
             'position_ids': data["position_ids"].cuda(non_blocking = True)
         }

         if args.pipeline_model_parallel_size == 1:
             _broadcast(batch['tokens'])
             _broadcast(batch['labels'])
             _broadcast(batch['loss_mask'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_first_stage():
             _broadcast(batch['tokens'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
-                _broadcast(batch['tokens'])
-            _broadcast(batch['labels'])
-            _broadcast(batch['loss_mask'])
-            _broadcast(batch['attention_mask'])
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(batch['position_ids'])
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
+                _broadcast(batch['tokens'])
+                _broadcast(batch['position_ids'])
+            _broadcast(batch['labels'])
+            _broadcast(batch['loss_mask'])
+            _broadcast(batch['attention_mask'])

     else:
-        tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                           dtype = torch.int64,
-                           device = torch.cuda.current_device())
-        labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                           dtype = torch.int64,
-                           device = torch.cuda.current_device())
-        loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                              dtype = torch.float32,
-                              device = torch.cuda.current_device())
-        if args.create_attention_mask_in_dataloader:
-            attention_mask=torch.empty(
-                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
-                 args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
-                device = torch.cuda.current_device()
+        tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
+        labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
+        loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
+        if args.create_attention_mask_in_dataloader:
+            attention_mask=torch.empty(
+                (args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
             )
         else:
             attention_mask=None
-        position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                                 dtype = torch.int64,
-                                 device = torch.cuda.current_device())
+        position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())

         if args.pipeline_model_parallel_size == 1:
             _broadcast(tokens)
             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_first_stage():
             labels=None
             loss_mask=None

             _broadcast(tokens)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
-                _broadcast(tokens)
-            else:
-                tokens = None
-            _broadcast(labels)
-            _broadcast(loss_mask)
-            _broadcast(attention_mask)
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(position_ids)
-            else:
-                position_ids = None
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
+                _broadcast(tokens)
+                _broadcast(position_ids)
+            else:
+                tokens=None
+                position_ids=None
+
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)

         batch = {
             'tokens': tokens,
             'labels': labels,
             'loss_mask': loss_mask,
             'attention_mask': attention_mask,
             'position_ids': position_ids
         }

     return batch
@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
 )
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
-from dcu_megatron.core.utils import tensor_slide
+from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
 from dcu_megatron import megatron_adaptor
@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")

     # Define the mtp layer spec
-    if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
-        mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
-    else:
-        mtp_transformer_layer_spec = transformer_layer_spec
+    mtp_block_spec = None
+    if args.mtp_num_layers is not None:
+        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
+        mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

     with build_model_context(**build_model_context_args):
-        config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
         model = GPTModel(
             config=config,
             transformer_layer_spec=transformer_layer_spec,
@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             position_embedding_type=args.position_embedding_type,
             rotary_percent=args.rotary_percent,
             rotary_base=args.rotary_base,
-            rope_scaling=args.use_rope_scaling
+            rope_scaling=args.use_rope_scaling,
+            mtp_block_spec=mtp_block_spec,
         )
     # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
     print_rank_0(model)
@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     args = get_args()

     losses = output_tensor.float()
-    if getattr(args, "num_nextn_predict_layers", 0) > 0:
-        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
     loss_mask = loss_mask.view(-1).float()
     total_tokens = loss_mask.sum()
     loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator').stop()

     with stimer:
-        output_tensor = model(tokens, position_ids, attention_mask,
-                              labels=labels)
+        if args.use_legacy_models:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels, loss_mask=loss_mask)

     return output_tensor, partial(loss_func, loss_mask)
@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
     return GPTDatasetConfig(
         random_seed=args.seed,
-        sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
+        sequence_length=args.seq_length,
         blend=blend,
         blend_per_split=blend_per_split,
         split=args.split,
...