Commit f00f0256 authored by dongcl

deepseek mtp bug fix

parent 627a739f
Pipeline #2462 passed
@@ -17,9 +17,11 @@ from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
+from megatron.core.extensions.transformer_engine import TENorm


class GPTModel(LanguageModule):
@@ -137,6 +139,7 @@ class GPTModel(LanguageModule):
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
+            num_nextn_predict_layers=num_nextn_predict_layers
        )

        # Output
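The rest of gpt_model.py is collapsed in this view, so the use of the two new imports is not shown. A plausible reading, sketched below as an assumption rather than the commit's actual code, is that when `num_nextn_predict_layers > 0` the GPT model builds its own final layernorm via `build_module`/`TENorm` instead of relying on the one inside `TransformerBlock`, mirroring the `build_module(...)` call the block itself uses:

```python
# Sketch only: hypothetical placement inside GPTModel.__init__, not taken from this commit.
if self.post_process and num_nextn_predict_layers > 0:
    # The main model owns its final norm so the MTP modules can keep separate norms.
    self.final_layernorm = build_module(
        TENorm,                      # build_module accepts a class or a ModuleSpec
        config=self.config,
        hidden_size=self.config.hidden_size,
        eps=self.config.layernorm_epsilon,
    )
```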
@@ -178,6 +178,7 @@ class TransformerBlock(MegatronModule):
        post_layer_norm: bool = True,
        pre_process: bool = True,
        post_process: bool = True,
+        num_nextn_predict_layers: int = 0
    ):
        super().__init__(config=config)
@@ -185,6 +186,7 @@ class TransformerBlock(MegatronModule):
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
+        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Dictionary to store CUDA graphs. Number of items in the dictionary = len(self.layers).
        # Item `i` in the dictionary is a list of `N` CUDA graphs for layer 'i' where N is the
        # number of microbatches. Multiple CUDA graphs per layer is required to support
@@ -246,7 +248,7 @@ class TransformerBlock(MegatronModule):
        # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
        # self.post_process and self.post_layer_norm guide this behavior
        # MTP requires separate layernorms for the main model and the MTP modules, so the final norm is moved out of the block
-        move_final_norm_out_of_block = args.num_nextn_predict_layers > 0
+        move_final_norm_out_of_block = self.num_nextn_predict_layers > 0
        if self.submodules.layer_norm and self.post_process and self.post_layer_norm and not move_final_norm_out_of_block:
            self.final_layernorm = build_module(
                self.submodules.layer_norm,
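The actual bug this hunk fixes is the `args` reference: transformer_block.py lives in megatron.core, where the training-side `args` global from `get_args()` is not in scope, so `move_final_norm_out_of_block = args.num_nextn_predict_layers > 0` could not be evaluated there. The value is now threaded in through the constructor by GPTModel. A hedged usage sketch (`config` and `transformer_layer_spec` assumed to be built the usual way):

```python
# Illustrative only; mirrors how GPTModel now constructs its decoder.
block = TransformerBlock(
    config=config,                   # a TransformerConfig built elsewhere
    spec=transformer_layer_spec,
    pre_process=True,
    post_process=True,
    num_nextn_predict_layers=1,      # > 0 enables MTP, e.g. one extra predicted token
)
# With num_nextn_predict_layers > 0 the block skips building final_layernorm;
# the GPT model (and each MTP module) applies its own final norm instead.
```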
@@ -388,105 +388,104 @@ def get_batch_on_this_tp_rank(data_iterator):

    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:
        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking = True),
            'labels': data["labels"].cuda(non_blocking = True),
            'loss_mask': data["loss_mask"].cuda(non_blocking = True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
            'position_ids': data["position_ids"].cuda(non_blocking = True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:
                _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(batch['position_ids'])

    else:
        tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
                           dtype = torch.int64,
                           device = torch.cuda.current_device())
        labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
                           dtype = torch.int64,
                           device = torch.cuda.current_device())
        loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
                              dtype = torch.float32,
                              device = torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask=torch.empty(
                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
                 args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
                device = torch.cuda.current_device()
            )
        else:
            attention_mask=None
        position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
                                 dtype = torch.int64,
                                 device = torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_first_stage():
            labels=None
            loss_mask=None

            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:
                _broadcast(tokens)
            else:
                tokens = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(position_ids)
            else:
                position_ids = None

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

    return batch
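The shape changes above all follow from extending every per-sample tensor to `seq_length + num_nextn_predict_layers` so the next-n targets travel with the batch; non-source tensor-parallel ranks must allocate empty buffers of exactly that extended shape before receiving the broadcast. A stripped-down, Megatron-independent sketch of that allocate-then-broadcast pattern:

```python
import torch
import torch.distributed as dist

def broadcast_tokens(tokens, micro_batch_size, seq_length, num_nextn_predict_layers, src=0):
    """Rank `src` supplies `tokens`; every other rank allocates an empty buffer of the
    MTP-extended shape (seq_length + num_nextn_predict_layers) and receives into it."""
    if dist.get_rank() != src:
        tokens = torch.empty(
            (micro_batch_size, seq_length + num_nextn_predict_layers),
            dtype=torch.int64,
            device=torch.cuda.current_device(),
        )
    dist.broadcast(tokens, src)
    return tokens
```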
@@ -37,6 +37,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
)
from megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
+from megatron.core.utils import tensor_slide

import torch._dynamo
torch._dynamo.config.suppress_errors = True
@@ -190,6 +191,8 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    args = get_args()

    losses = output_tensor.float()
+    if args.num_nextn_predict_layers > 0:
+        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
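`tensor_slide` itself is not part of this diff, so its exact behavior is an assumption here; from the call site it appears to cut the MTP-extended loss mask back down to windows of the original `seq_length`, with `return_first=True` keeping only the window aligned with the main model's logits. A guess at a compatible implementation (the real megatron.core.utils.tensor_slide may differ):

```python
import torch

def tensor_slide(x: torch.Tensor, n: int, return_first: bool = False):
    """Hypothetical stand-in: split a (..., s + n) tensor into n + 1 sliding windows
    of length s along the last dimension. With return_first=True, return a list
    containing only the first window, matching the [0] indexing at the call site."""
    s = x.size(-1) - n
    windows = [x[..., i:i + s] for i in range(n + 1)]
    return windows[:1] if return_first else windows

# Call-site shape check: loss_mask of shape (mbs, seq_len + n) -> (mbs, seq_len)
# loss_mask = tensor_slide(loss_mask, n, return_first=True)[0]
```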