Commit 770fa304 authored by dongcl

Modify MTP

parent 8096abd4
@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
     if check_equality:
         return get_flux_version() >= PkgVersion(version)
     return get_flux_version() > PkgVersion(version)
-
-
-def tensor_slide(
-    tensor: Optional[torch.Tensor],
-    num_slice: int,
-    dims: Union[int, List[int]] = -1,
-    step: int = 1,
-    return_first=False,
-) -> List[Union[torch.Tensor, None]]:
-    """Generic sliding-window helper that supports arbitrary dimensions."""
-    if tensor is None:
-        # return `List[None]` to avoid NoneType errors
-        return [None] * (num_slice + 1)
-
-    if num_slice == 0:
-        return [tensor]
-
-    window_size = tensor.shape[-1] - num_slice
-    dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
-
-    # slide over the selected dimensions simultaneously
-    slices = []
-    for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
-        slice_obj = [slice(None)] * tensor.dim()
-        for dim in dims:
-            slice_obj[dim] = slice(i, i + window_size)
-        slices.append(tensor[tuple(slice_obj)])
-
-    if return_first:
-        return slices
-    return slices
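For reference, the `tensor_slide` helper removed above returns the `num_slice + 1` shifted windows of width `size(-1) - num_slice` along the sliding dimension; the old MTP path called it with `return_first=True` and took element `[0]` to trim the loss mask (see the `loss_func` hunk below). A minimal illustration, runnable only against a checkout that still contains the helper:

```python
import torch
from dcu_megatron.core.utils import tensor_slide  # removed by this commit

x = torch.arange(6).unsqueeze(0)        # shape (1, 6): [[0, 1, 2, 3, 4, 5]]
windows = tensor_slide(x, num_slice=2)  # window width = 6 - 2 = 4

# Three shifted views of shape (1, 4):
#   [[0, 1, 2, 3]], [[1, 2, 3, 4]], [[2, 3, 4, 5]]
print([w.tolist() for w in windows])
```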
@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
 def _add_mtp_args(parser):
     group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num')
-    group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale')
-    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
-                       help='Multi-Token prediction recompute norm')
-    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
-                       help='Multi-Token prediction recompute layer')
-    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                       help='Main model share embedding and output weight with mtp layer.')
+    group.add_argument('--mtp-num-layers', type=int, default=None,
+                       help='Number of Multi-Token Prediction (MTP) layers. '
+                            'MTP extends the prediction scope to multiple future tokens at each position. '
+                            'This MTP implementation sequentially predicts additional tokens '
+                            'by using D sequential modules to predict D additional tokens.')
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.1,
+                       help='Scaling factor of the Multi-Token Prediction (MTP) loss. '
+                            'We compute the average of the MTP losses across all depths, '
+                            'and multiply it by the scaling factor to obtain the overall MTP loss, '
+                            'which serves as an additional training objective.')

     return parser
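The `--mtp-loss-scaling-factor` help text describes the intended combination rule: average the per-depth MTP losses, then scale the average. A minimal sketch of that rule as described (function and variable names are illustrative, not taken from the repository):

```python
import torch

def overall_mtp_loss(per_depth_losses, mtp_loss_scaling_factor=0.1):
    """Average the MTP losses across all D depths and scale the result,
    as described in the --mtp-loss-scaling-factor help text."""
    return mtp_loss_scaling_factor * torch.stack(per_depth_losses).mean()

# Example with D = 2 prediction depths:
losses = [torch.tensor(2.0), torch.tensor(2.2)]
print(overall_mtp_loss(losses))  # 0.1 * 2.1 = tensor(0.2100)
```

The scaled term serves as an additional training objective alongside the main language-model loss, per the help text.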
@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):
     args = get_args()

     def _broadcast(item):
         if item is not None:
             torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

     if mpu.get_tensor_model_parallel_rank() == 0:

         if data_iterator is not None:
             data = next(data_iterator)
         else:
             data = None

         batch = {
             'tokens': data["tokens"].cuda(non_blocking = True),
             'labels': data["labels"].cuda(non_blocking = True),
             'loss_mask': data["loss_mask"].cuda(non_blocking = True),
             'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
             'position_ids': data["position_ids"].cuda(non_blocking = True)
         }

         if args.pipeline_model_parallel_size == 1:
             _broadcast(batch['tokens'])
             _broadcast(batch['labels'])
             _broadcast(batch['loss_mask'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_first_stage():
             _broadcast(batch['tokens'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
-                _broadcast(batch['tokens'])
-            _broadcast(batch['labels'])
-            _broadcast(batch['loss_mask'])
-            _broadcast(batch['attention_mask'])
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(batch['position_ids'])
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
+                _broadcast(batch['tokens'])
+                _broadcast(batch['position_ids'])
+            _broadcast(batch['labels'])
+            _broadcast(batch['loss_mask'])
+            _broadcast(batch['attention_mask'])

     else:
-        tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                           dtype = torch.int64,
-                           device = torch.cuda.current_device())
-        labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                           dtype = torch.int64,
-                           device = torch.cuda.current_device())
-        loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                              dtype = torch.float32,
-                              device = torch.cuda.current_device())
-        if args.create_attention_mask_in_dataloader:
-            attention_mask=torch.empty(
-                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
-                 args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
-                device = torch.cuda.current_device()
-            )
-        else:
-            attention_mask=None
-        position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                                 dtype = torch.int64,
-                                 device = torch.cuda.current_device())
+        tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
+        labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
+        loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
+        if args.create_attention_mask_in_dataloader:
+            attention_mask=torch.empty(
+                (args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
+            )
+        else:
+            attention_mask=None
+        position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())

         if args.pipeline_model_parallel_size == 1:
             _broadcast(tokens)
             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_first_stage():
             labels=None
             loss_mask=None

             _broadcast(tokens)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
-                _broadcast(tokens)
-            else:
-                tokens = None
-            _broadcast(labels)
-            _broadcast(loss_mask)
-            _broadcast(attention_mask)
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(position_ids)
-            else:
-                position_ids = None
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
+                _broadcast(tokens)
+                _broadcast(position_ids)
+            else:
+                tokens=None
+                position_ids=None
+
+            _broadcast(labels)
+            _broadcast(loss_mask)
+            _broadcast(attention_mask)

         batch = {
             'tokens': tokens,
             'labels': labels,
             'loss_mask': loss_mask,
             'attention_mask': attention_mask,
             'position_ids': position_ids
         }

     return batch
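The `_broadcast` helper above relies on the usual `torch.distributed.broadcast` contract: every rank in the tensor-parallel group must participate with a tensor of matching shape and dtype, which is why the non-source ranks pre-allocate `torch.empty` buffers of exactly `args.seq_length`. A stripped-down sketch of that contract (the wrapper name is illustrative; `mpu` is assumed to be the usual `megatron.core` alias for `parallel_state`):

```python
import torch
import torch.distributed as dist
from megatron.core import mpu  # alias of megatron.core.parallel_state

def broadcast_from_tp_src(item, shape, dtype):
    """Broadcast `item` from the TP source rank to every rank in the TP group."""
    if mpu.get_tensor_model_parallel_rank() != 0:
        # Non-source ranks receive into an empty buffer of the agreed shape/dtype.
        item = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
    dist.broadcast(item,
                   mpu.get_tensor_model_parallel_src_rank(),
                   group=mpu.get_tensor_model_parallel_group())
    return item
```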
@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
 )
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
-from dcu_megatron.core.utils import tensor_slide
+from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec

 from dcu_megatron import megatron_adaptor
@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")

     # Define the mtp layer spec
-    if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
-        mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
-    else:
-        mtp_transformer_layer_spec = transformer_layer_spec
+    mtp_block_spec = None
+    if args.mtp_num_layers is not None:
+        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
+        mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

     with build_model_context(**build_model_context_args):
-        config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
         model = GPTModel(
             config=config,
             transformer_layer_spec=transformer_layer_spec,
@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             position_embedding_type=args.position_embedding_type,
             rotary_percent=args.rotary_percent,
             rotary_base=args.rotary_base,
-            rope_scaling=args.use_rope_scaling
+            rope_scaling=args.use_rope_scaling,
+            mtp_block_spec=mtp_block_spec,
         )
         # model = torch.compile(model,mode='max-autotune-no-cudagraphs')

     print_rank_0(model)
@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     args = get_args()

     losses = output_tensor.float()
-    if getattr(args, "num_nextn_predict_layers", 0) > 0:
-        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
     loss_mask = loss_mask.view(-1).float()
     total_tokens = loss_mask.sum()
     loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
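The unchanged reduction above packs the masked loss sum and the number of contributing tokens into one two-element tensor, so the mean loss can be recovered after the two sums are accumulated across ranks. A tiny numeric illustration (values made up):

```python
import torch

losses = torch.tensor([[1.0, 2.0, 3.0, 4.0]])     # per-token losses
loss_mask = torch.tensor([[1.0, 1.0, 0.0, 1.0]])  # 0 marks tokens excluded from the loss

loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()                    # 3 tokens contribute
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

print(loss)               # tensor([7., 3.])
print(loss[0] / loss[1])  # mean loss over contributing tokens: tensor(2.3333)
```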
@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator').stop()

     with stimer:
-        output_tensor = model(tokens, position_ids, attention_mask,
-                              labels=labels)
+        if args.use_legacy_models:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels, loss_mask=loss_mask)

     return output_tensor, partial(loss_func, loss_mask)
@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
     return GPTDatasetConfig(
         random_seed=args.seed,
-        sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
+        sequence_length=args.seq_length,
         blend=blend,
         blend_per_split=blend_per_split,
         split=args.split,