Commit 1f7b14ab authored by sdwldchl

rewrite mtp

parent 89d29a02
......@@ -5,6 +5,8 @@ import types
import argparse
import torch
from .adaptor_arguments import get_adaptor_args
class MegatronAdaptation:
"""
......@@ -21,6 +23,15 @@ class MegatronAdaptation:
for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
adaptation.execute()
MegatronAdaptation.apply()
# from .patch_utils import MegatronPatchesManager
# args = get_adaptor_args()
# for feature in FEATURES_LIST:
# if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
# feature.register_patches(MegatronPatchesManager, args)
# MindSpeedPatchesManager.apply_patches()
# MegatronAdaptation.post_execute()
@classmethod
......@@ -87,38 +98,20 @@ class CoreAdaptation(MegatronAdaptationABC):
self.patch_miscellaneous()
def patch_core_distributed(self):
# Mtp share embedding
# mtp share embedding
from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
_allreduce_word_embedding_grads)
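For orientation, registering by dotted path amounts to importing the target module and rebinding the named attribute. The helper below is only an illustrative sketch of that idea, not the actual MegatronAdaptation implementation, and unlike the real manager it only handles module-level attributes.
import importlib

def apply_patch_sketch(dotted_path, replacement):
    # e.g. 'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads'
    module_path, attr_name = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    setattr(module, attr_name, replacement)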
def patch_core_models(self):
from ..core.models.common.embeddings.language_model_embedding import (
language_model_embedding_forward,
language_model_embedding_init_func
)
from ..core.models.gpt.gpt_model import (
gpt_model_forward,
gpt_model_init,
shared_embedding_or_output_weight,
)
from ..core.models.common.language_module.language_module import (
setup_embeddings_and_output_layer,
tie_embeddings_and_output_weights_state_dict
tie_embeddings_and_output_weights_state_dict,
)
from ..core.models.gpt.gpt_model import GPTModel
from ..training.utils import get_batch_on_this_tp_rank
# Embedding
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
language_model_embedding_init_func)
MegatronAdaptation.register(
'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
language_model_embedding_forward)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
# LanguageModule
MegatronAdaptation.register(
'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
setup_embeddings_and_output_layer)
......@@ -126,17 +119,16 @@ class CoreAdaptation(MegatronAdaptationABC):
'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
tie_embeddings_and_output_weights_state_dict)
MegatronAdaptation.register(
'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
shared_embedding_or_output_weight)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
# GPT Model
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block
# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
transformer_block_init_wrapper)
......@@ -174,13 +166,10 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_tensor_parallel(self):
from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper
# VocabParallelEmbedding
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
vocab_parallel_embedding_forward)
MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
vocab_parallel_embedding_init_wrapper,
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
# VocabParallelCrossEntropy
......@@ -211,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
get_gpt_layer_with_flux_spec)
def patch_pipeline_parallel(self):
from ..core.pipeline_parallel.schedules import forward_step_wrapper
# pipeline_parallel.schedules.forward_step
MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
forward_step_wrapper,
apply_wrapper=True)
def patch_training(self):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
......@@ -255,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
parallel_mlp_init_wrapper,
apply_wrapper=True)
# ParallelAttention
MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
parallel_attention_init_wrapper,
apply_wrapper=True)
......
......@@ -148,11 +148,29 @@ class MegatronPatchesManager:
patches_info = {}
@staticmethod
def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
def register_patch(
orig_func_or_cls_name,
new_func_or_cls=None,
force_patch=False,
create_dummy=False,
apply_wrapper=False,
remove_origin_wrappers=False
):
if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
orig_func_or_cls_name,
new_func_or_cls,
create_dummy,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
else:
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
new_func_or_cls,
force_patch,
apply_wrapper=apply_wrapper,
remove_origin_wrappers=remove_origin_wrappers
)
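As a usage sketch of the extended signature, a wrapper-style patch is registered with apply_wrapper=True so the original callable is fed through the wrapper rather than replaced outright; the wrapper below is hypothetical and only illustrates the shape of such a patch.
from functools import wraps

def logging_wrapper(fn):                      # hypothetical wrapper, for illustration only
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # pre-/post-processing around the original Megatron function goes here
        return fn(*args, **kwargs)
    return wrapper

MegatronPatchesManager.register_patch(
    'megatron.core.pipeline_parallel.schedules.forward_step',
    logging_wrapper,
    apply_wrapper=True,                       # wrap the original instead of replacing it
)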
@staticmethod
def apply_patches():
......
......@@ -28,7 +28,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in the post processing stage. If using Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in the mtp process stage.
# So we need to allreduce the embedding grads in the embedding group in these cases.
if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
......
......@@ -4,6 +4,7 @@ import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
......@@ -27,7 +28,7 @@ def setup_embeddings_and_output_layer(self) -> None:
# So we need to copy embedding weights from pre processing stage as initial parameters
# in these cases.
if not self.share_embeddings_and_output_weights and not getattr(
self.config, 'num_nextn_predict_layers', 0
self.config, 'mtp_num_layers', 0
):
return
......@@ -41,10 +42,10 @@ def setup_embeddings_and_output_layer(self) -> None:
if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
self.shared_embedding_or_output_weight().shared_embedding = True
if self.post_process and not self.pre_process:
if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
# set weights of the duplicated embedding to 0 here,
# then copy weights from pre processing stage using all_reduce below.
weight = self.shared_embedding_or_output_weight()
weight.data.fill_(0)
weight.shared = True
......@@ -114,7 +115,7 @@ def tie_embeddings_and_output_weights_state_dict(
# layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
# the shared weights will be stored in embedding layer, and output layer will not have
# any weight.
if self.post_process and getattr(self, 'num_nextn_predict_layers', False):
if getattr(self, 'mtp_process', False):
# No output layer
assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
return
......
import warnings
from typing import Optional
from typing import Optional, Union
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
......@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version
try:
......@@ -36,6 +36,55 @@ try:
except ImportError:
warnings.warn('Apex is not installed.')
from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlockSubmodules,
get_mtp_layer_offset,
get_mtp_layer_spec,
get_mtp_num_layers_to_build,
)
def get_gpt_mtp_block_spec(
config: TransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
use_transformer_engine: bool,
) -> MultiTokenPredictionBlockSubmodules:
"""GPT Multi-Token Prediction (MTP) block spec."""
num_layers_to_build = get_mtp_num_layers_to_build(config)
if num_layers_to_build == 0:
return None
if isinstance(spec, TransformerBlockSubmodules):
# get the spec for the last layer of decoder block
transformer_layer_spec = spec.layer_specs[-1]
elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
transformer_layer_spec = spec
else:
raise ValueError(f"Invalid spec: {spec}")
mtp_layer_spec = get_mtp_layer_spec(
transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
)
mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
offset = get_mtp_layer_offset(config)
# split the mtp layer specs to only include the layers that are built in this pipeline stage.
mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
if len(mtp_layer_specs) > 0:
assert (
len(mtp_layer_specs) == config.mtp_num_layers
), +f"currently all of the mtp layers must stage in the same pipeline stage."
mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
else:
mtp_block_spec = None
return mtp_block_spec
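A minimal call-site sketch, mirroring how model_provider later in this commit builds the MTP block spec only when MTP is enabled; config, transformer_layer_spec and use_te are assumed to already exist.
mtp_block_spec = None
if config.mtp_num_layers is not None:
    mtp_block_spec = get_gpt_mtp_block_spec(
        config, transformer_layer_spec, use_transformer_engine=use_te
    )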
def get_gpt_layer_with_flux_spec(
num_experts: Optional[int] = None,
......
This diff is collapsed.
import torch
from functools import wraps
from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
def forward_step_wrapper(fn):
@wraps(fn)
def wrapper(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs,
):
output, num_tokens = fn(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs
)
if not isinstance(input_tensor, list):
# unwrap_output_tensor True
output_tensor = output
else:
output_tensor = output[0]
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output, num_tokens
return wrapper
\ No newline at end of file
from .layers import (
FluxColumnParallelLinear,
FluxRowParallelLinear,
vocab_parallel_embedding_forward,
vocab_parallel_embedding_init_wrapper,
)
\ No newline at end of file
import os
import copy
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
try:
import flux
from dcu_megatron.core.utils import is_flux_min_version
except ImportError:
raise ImportError("flux is NOT installed")
try:
import flux
except ImportError:
raise ImportError("flux is NOT installed")
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
is_torch_min_version,
prepare_input_tensors_for_wgrad_compute
)
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
VocabParallelEmbedding,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.mappings import (
_reduce,
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
_reduce_scatter_along_first_dim,
_gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
......@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
custom_fwd,
custom_bwd,
dist_all_gather_func,
linear_with_frozen_weight,
linear_with_grad_accumulation_and_async_allreduce
)
from dcu_megatron.core.utils import is_flux_min_version
_grad_accum_fusion_available = True
try:
......@@ -61,74 +41,6 @@ except ImportError:
_grad_accum_fusion_available = False
def vocab_parallel_embedding_init_wrapper(fn):
@wraps(fn)
def wrapper(self,
*args,
skip_weight_param_allocation: bool = False,
**kwargs
):
if (
skip_weight_param_allocation
and "config" in kwargs
and hasattr(kwargs["config"], "perform_initialization")
):
config = copy.deepcopy(kwargs["config"])
config.perform_initialization = False
kwargs["config"] = config
fn(self, *args, **kwargs)
if skip_weight_param_allocation:
self.weight = None
return wrapper
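A hedged usage sketch of the init wrapper together with the patched forward below: with skip_weight_param_allocation=True the module keeps no weight of its own, so a shared weight must be passed at forward time. The constructor keywords follow the upstream VocabParallelEmbedding signature and, like shared_weight, are assumptions for illustration.
embedding = VocabParallelEmbedding(
    num_embeddings=vocab_size,
    embedding_dim=hidden_size,
    init_method=config.init_method,
    config=config,
    skip_weight_param_allocation=True,   # consumed by the wrapper; self.weight stays None
)
output = embedding(token_ids, weight=shared_weight)  # weight kwarg handled by the patched forward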
@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
"""Forward.
Args:
input_ (torch.Tensor): Input tensor.
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
if self.reduce_scatter_embeddings:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
def get_tensor_model_parallel_node_size(group=None):
""" 获取节点数
"""
......
This diff is collapsed.
......@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
# mtp requires separate layernorms for the main model and the mtp modules, thus move the final layernorm out of the block
config = args[0] if len(args) > 1 else kwargs['config']
if getattr(config, "num_nextn_predict_layers", 0) > 0:
if getattr(config, "mtp_num_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
......
from typing import Optional
from functools import wraps
from dataclasses import dataclass
from megatron.training import get_args
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
def transformer_config_post_init_wrapper(fn):
@wraps(fn)
def wrapper(self):
fn(self)
args = get_args()
"""Number of Multi-Token Prediction (MTP) Layers."""
self.mtp_num_layers = args.mtp_num_layers
"""Weighting factor of Multi-Token Prediction (MTP) loss."""
self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
##################
# flux
##################
self.flux_transpose_weight = args.flux_transpose_weight
return wrapper
@dataclass
class ExtraTransformerConfig:
##################
# multi-token prediction
##################
num_nextn_predict_layers: int = 0
"""The number of multi-token prediction layers"""
mtp_loss_scale: float = 0.3
"""Multi-token prediction loss scale"""
recompute_mtp_norm: bool = False
"""Whether to recompute mtp normalization"""
recompute_mtp_layer: bool = False
"""Whether to recompute mtp layer"""
mtp_num_layers: Optional[int] = None
"""Number of Multi-Token Prediction (MTP) Layers."""
share_mtp_embedding_and_output_weight: bool = False
"""share embedding and output weight with mtp layer."""
mtp_loss_scaling_factor: Optional[float] = None
"""Weighting factor of Multi-Token Prediction (MTP) loss."""
##################
# flux
......
......@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale')
group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
help='Multi-Token prediction recompute norm')
group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
help='Multi-Token prediction recompute layer')
group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
help='Main model share embedding and output weight with mtp layer.')
group.add_argument('--mtp-num-layers', type=int, default=None,
help='Number of Multi-Token Prediction (MTP) Layers.'
'MTP extends the prediction scope to multiple future tokens at each position.'
'This MTP implementation sequentially predicts additional tokens '
'by using D sequential modules to predict D additional tokens.')
group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
help='Scaling factor of Multi-Token Prediction (MTP) loss. '
'We compute the average of the MTP losses across all depths, '
'and multiply it by the scaling factor to obtain the overall MTP loss, '
'which serves as an additional training objective.')
return parser
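A quick sketch of the resulting defaults, using only the flags defined above; the empty-argv parse is illustrative.
import argparse

parser = argparse.ArgumentParser()
parser = _add_mtp_args(parser)

args = parser.parse_args([])                         # MTP disabled by default
assert args.mtp_num_layers is None
assert args.mtp_loss_scaling_factor == 0.3

args = parser.parse_args(['--mtp-num-layers', '2'])  # enable two sequential MTP modules
assert args.mtp_num_layers == 2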
......
......@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):
args = get_args()
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if mpu.get_tensor_model_parallel_rank() == 0:
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
}
if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers:
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
}
if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers are fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(batch['position_ids'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
else:
tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.float32,
device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
device = torch.cuda.current_device()
tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_last_stage():
if args.num_nextn_predict_layers:
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
# Currently the Multi-Token Prediction (MTP) layers are fixed on the last stage, so we need
# to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(tokens)
else:
tokens = None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(position_ids)
else:
position_ids = None
batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
else:
tokens=None
position_ids=None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
return batch
......@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
from dcu_megatron import megatron_adaptor
......@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
# Define the mtp layer spec
if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
else:
mtp_transformer_layer_spec = transformer_layer_spec
mtp_block_spec = None
if args.mtp_num_layers is not None:
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
with build_model_context(**build_model_context_args):
config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
......@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling
rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec,
)
# model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model)
......@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
args = get_args()
losses = output_tensor.float()
if getattr(args, "num_nextn_predict_layers", 0) > 0:
loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
......@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator').stop()
with stimer:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels, loss_mask=loss_mask)
return output_tensor, partial(loss_func, loss_mask)
......@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
sequence_length=args.seq_length,
blend=blend,
blend_per_split=blend_per_split,
split=args.split,
......