Commit 627a739f authored by dongcl

Support DeepSeek MTP and auxiliary-loss-free load balancing

parent 1310cbf8
Pipeline #2444 passed
......@@ -13,6 +13,7 @@ except ImportError:
HAVE_DTENSOR = False
from .. import parallel_state
from ..transformer.moe.moe_utils import get_updated_expert_bias
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
......@@ -135,6 +136,15 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
if hasattr(model_module,
"share_mtp_embedding_and_output_weight") and model_module.share_mtp_embedding_and_output_weight:
weight = model_module.shared_embedding_or_mtp_embedding_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
......@@ -184,9 +194,8 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if (
param.requires_grad
and getattr(param, 'sequence_parallel', False)
if param.requires_grad and (
getattr(param, 'sequence_parallel', False)
or 'q_layernorm' in name
or 'k_layernorm' in name
):
......@@ -209,6 +218,34 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer
setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
"""
Update the expert bias of the router for a global batch.
This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks
"""
tokens_per_expert_list = []
expert_bias_list = []
for model_chunk in model:
for module in get_attr_wrapped_model(model_chunk, 'modules')():
if hasattr(module, 'expert_bias'):
tokens_per_expert_list.append(module.local_tokens_per_expert)
expert_bias_list.append(module.expert_bias)
# For hybrid models with both MoE and Dense layers, this list can be empty.
if len(expert_bias_list) == 0:
return
stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
stacked_updated_expert_bias = get_updated_expert_bias(
stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
)
for tokens_per_expert, expert_bias, updated_expert_bias in zip(
tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
):
tokens_per_expert.zero_()
expert_bias.copy_(updated_expert_bias)
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
......@@ -253,6 +290,9 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
if config.moe_router_enable_expert_bias:
_update_router_expert_bias(model, config)
# normalize gradients for per-token loss normalization.
# if we are using by the number of tokens, then we use that as a divisor. this number
# will be the total number of non-padded tokens in the global batch.
......
......@@ -33,6 +33,7 @@ class LanguageModelEmbedding(MegatronModule):
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
num_tokentypes: int = 0,
scatter_to_sequence_parallel: bool = True,
skip_weight_param_allocation: bool = False
):
super().__init__(config=config)
......@@ -56,6 +57,7 @@ class LanguageModelEmbedding(MegatronModule):
init_method=self.config.init_method,
reduce_scatter_embeddings=self.reduce_scatter_embeddings,
config=self.config,
skip_weight_param_allocation=skip_weight_param_allocation
)
# Position embedding (serial).
......@@ -91,7 +93,7 @@ class LanguageModelEmbedding(MegatronModule):
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None) -> Tensor:
def forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None, weight: Tensor = None) -> Tensor:
"""Forward pass of the embedding module.
Args:
......@@ -99,11 +101,20 @@ class LanguageModelEmbedding(MegatronModule):
position_ids (Tensor): The position id's used to calculate position embeddings
tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
set to True. Defaults to None
weight (Tensor): embedding weight
Returns:
Tensor: The output embeddings
"""
word_embeddings = self.word_embeddings(input_ids)
if weight is None:
if self.word_embeddings.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.word_embeddings.weight
word_embeddings = self.word_embeddings(input_ids, weight)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = word_embeddings + position_embeddings
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from collections import OrderedDict
from typing import Dict, Literal, Optional
from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.utils import tensor_slide
from megatron.core import InferenceParams, tensor_parallel, parallel_state
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
......@@ -16,6 +19,7 @@ from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
class GPTModel(LanguageModule):
......@@ -75,6 +79,12 @@ class GPTModel(LanguageModule):
rope_scaling: bool = False,
scatter_embedding_sequence_parallel: bool = True,
seq_len_interpolation_factor: Optional[float] = None,
mtp_spec: ModuleSpec = None,
num_nextn_predict_layers: int = 0,
share_mtp_embedding_and_output_weight: bool = False,
recompute_mtp_norm: bool = False,
recompute_mtp_layer: bool = False,
mtp_loss_scale: float = 0.3
) -> None:
super().__init__(config=config)
......@@ -160,14 +170,140 @@ class GPTModel(LanguageModule):
grad_output_buffer=self.grad_output_buffer,
)
# add mtp
self.mtp_spec: ModuleSpec = mtp_spec
self.num_nextn_predict_layers = num_nextn_predict_layers
self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
self.recompute_mtp_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.mtp_loss_scale = mtp_loss_scale
if self.post_process and self.training and self.num_nextn_predict_layers:
self.mtp_layers = torch.nn.ModuleList(
[
MultiTokenPredictor(
config,
self.mtp_spec.submodules,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
layer_number=i,
pre_process=self.pre_process,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=seq_len_interpolation_factor,
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=self.recompute_mtp_norm,
recompute_mtp_layer=self.recompute_mtp_layer,
add_output_layer_bias=False
)
for i in range(self.num_nextn_predict_layers)
]
)
if self.post_process and self.num_nextn_predict_layers:
# move block main model final norms here
self.final_layernorm = build_module(
TENorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
if self.num_nextn_predict_layers:
self.setup_mtp_embeddings()
if has_config_logger_enabled(self.config):
log_config_to_disk(
self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
)
def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
"""Gets the embedding weight when share embedding and mtp embedding weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
mtp embedding layers weight
"""
assert self.num_nextn_predict_layers > 0
if self.pre_process:
return self.embedding.word_embeddings.weight
elif self.post_process:
return self.mtp_layers[0].embedding.word_embeddings.weight
return None
def setup_mtp_embeddings(self):
"""
Share embedding layer in mtp layer.
"""
if self.pre_process:
self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
# Set `is_embedding_or_output_parameter` attribute.
for i in range(self.num_nextn_predict_layers):
if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
if not self.share_mtp_embedding_and_output_weight:
return
if self.pre_process and self.post_process:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
return
if self.pre_process and not self.post_process:
assert parallel_state.is_pipeline_first_stage()
self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True
if self.post_process and not self.pre_process:
assert not parallel_state.is_pipeline_first_stage()
for i in range(self.num_nextn_predict_layers):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_mtp_embedding_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule.embedding_warning_printed = True
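Conceptually, the tying scheme described in the comment above reduces to zero-initializing the last stage's copy and summing the two copies. A minimal single-process sketch with plain torch, standing in for the real embedding-group all-reduce (tensor names and sizes are illustrative):

import torch

# Hypothetical two "stages" on one process, standing in for pipeline ranks.
first_stage_weight = torch.randn(8, 4)                       # embedding on the first stage
last_stage_weight = torch.zeros_like(first_stage_weight)     # MTP copy, zero-initialized

# A sum all-reduce over {first, last} is equivalent to adding the two copies,
# so afterwards both stages hold the first stage's initial values.
reduced = first_stage_weight + last_stage_weight
first_stage_weight, last_stage_weight = reduced.clone(), reduced.clone()

assert torch.equal(first_stage_weight, last_stage_weight)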
def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.
......@@ -184,6 +320,48 @@ class GPTModel(LanguageModule):
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])
def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
if self.num_nextn_predict_layers == 0:
return (
[input_ids],
[labels],
[position_ids],
[attention_mask],
)
return (
tensor_slide(input_ids, self.num_nextn_predict_layers),
tensor_slide(labels, self.num_nextn_predict_layers),
self.generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers),
# not compatible with ppo attn_mask
tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
)
def generate_nextn_position_ids(self, tensor, slice_num):
slides = tensor_slide(tensor, slice_num)
if slides[0] is None:
return slides
for idx in range(1, len(slides)):
slides[idx] = self.regenerate_position_ids(slides[idx], idx)
return slides
@staticmethod
def regenerate_position_ids(tensor, offset):
if tensor is None:
return None
tensor = tensor.clone()
for i in range(tensor.size(0)):
row = tensor[i]
zero_mask = (row == 0)  # case where two samples are packed into one sequence
if zero_mask.any():
first_zero_idx = torch.argmax(zero_mask.int()).item()
tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
else:
tensor[i] = tensor[i] - offset
return tensor
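A worked example of the position-id regeneration above; the logic is repeated inline so the snippet is self-contained, and the values are illustrative:

import torch

def regen(pos, offset):
    # Same logic as GPTModel.regenerate_position_ids above.
    pos = pos.clone()
    for i in range(pos.size(0)):
        zero_mask = pos[i] == 0
        if zero_mask.any():
            first_zero = torch.argmax(zero_mask.int()).item()
            pos[i, :first_zero] = torch.arange(first_zero)
        else:
            pos[i] = pos[i] - offset
    return pos

pos = torch.tensor([[1, 2, 3, 4, 5, 6],     # one sample, window shifted by 1
                    [3, 4, 5, 0, 1, 2]])    # two packed samples
print(regen(pos, offset=1))
# tensor([[0, 1, 2, 3, 4, 5],
#         [0, 1, 2, 0, 1, 2]])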
def forward(
self,
input_ids: Tensor,
......@@ -209,11 +387,18 @@ class GPTModel(LanguageModule):
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# generate inputs for main and mtps
input_ids, labels, position_ids, attention_mask = self.slice_inputs(
input_ids,
labels,
position_ids,
attention_mask)
# Decoder embedding.
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0])
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
......@@ -242,7 +427,7 @@ class GPTModel(LanguageModule):
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
attention_mask=attention_mask[0],
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
......@@ -258,6 +443,36 @@ class GPTModel(LanguageModule):
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
loss = 0
# Multi token prediction module
if self.num_nextn_predict_layers and self.training:
if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
output_weight = self.output_layer.weight
output_weight.zero_out_wgrad = True
embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
mtp_hidden_states = hidden_states
for i in range(self.num_nextn_predict_layers):
mtp_hidden_states, mtp_loss = self.mtp_layers[i](
mtp_hidden_states, # [s,b,h]
input_ids[i + 1],
position_ids[i + 1] if position_ids[0] is not None else None,
attention_mask[i + 1] if attention_mask[0] is not None else None,
labels[i + 1] if labels[0] is not None else None,
inference_params,
packed_seq_params,
extra_block_kwargs,
embedding_weight=embedding_weight,
output_weight=output_weight,
)
loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss
if self.num_nextn_predict_layers and self.final_layernorm is not None:
# move block main model final norms here
hidden_states = self.final_layernorm(hidden_states)
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
......@@ -265,21 +480,20 @@ class GPTModel(LanguageModule):
if has_config_logger_enabled(self.config):
payload = OrderedDict(
{
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'input_ids': input_ids[0],
'position_ids': position_ids[0],
'attention_mask': attention_mask[0],
'decoder_input': decoder_input,
'logits': logits,
}
)
log_config_to_disk(self.config, payload, prefix='input_and_logits')
if labels is None:
if labels[0] is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
loss += self.compute_language_model_loss(labels[0], logits)
return loss
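For reference, the weighting applied above combines the heads as: total loss = main LM loss + (mtp_loss_scale / N) * sum of the N MTP head losses. A purely illustrative scalar check (the real losses are per-token tensors):

# Hypothetical numbers, only to illustrate the scaling: N = 2 MTP heads,
# mtp_loss_scale = 0.3, MTP head losses 2.0 and 2.2, main LM loss 1.8.
total = 1.8 + 0.3 / 2 * (2.0 + 2.2)   # = 1.8 + 0.63 = 2.43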
def sharded_state_dict(
......
......@@ -191,6 +191,7 @@ class VocabParallelEmbedding(torch.nn.Module):
init_method: Callable,
reduce_scatter_embeddings: bool = False,
config: ModelParallelConfig,
skip_weight_param_allocation: bool = False
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
......@@ -210,40 +211,52 @@ class VocabParallelEmbedding(torch.nn.Module):
self.deterministic_mode = config.deterministic_mode
# Allocate weights and initialize.
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
)
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
self.weight = None
@torch.compile(mode='max-autotune-no-cudagraphs')
def forward(self, input_):
def forward(self, input_, weight=None):
"""Forward.
Args:
input_ (torch.Tensor): Input tensor.
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to VocabParallelEmbedding forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
......@@ -254,10 +267,10 @@ class VocabParallelEmbedding(torch.nn.Module):
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = self.weight[masked_input]
output_parallel = weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, self.weight)
output_parallel = F.embedding(masked_input, weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
......
......@@ -312,3 +312,101 @@ def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, distribute_saved_activations, *args)
class CheckpointFunctionWithoutOutput(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, checkpoint, *args):
with torch.no_grad():
outputs = run_function(*args)
# Store everything
ctx.save_for_backward(*detach_variable(args))
checkpoint.ctx = ctx
return outputs
@staticmethod
def backward(ctx, *args):
inputs = ctx.saved_tensors
outputs = ctx.outputs
torch.autograd.backward(outputs, args)
ctx.outputs = None
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
return (None, None) + grads
class CheckpointWithoutOutput:
def __init__(self):
self.run_function = None
self.fwd_cpu_rng_state = None
self.fwd_cuda_rng_state = None
self.fwd_cuda_rng_state_tracker = None
self.outputs = None
def checkpoint(self, run_function, distribute_saved_activations, *args):
self.run_function = run_function
if distribute_saved_activations:
raise RuntimeError(
"CheckpointFunctionWithoutOutput does not support "
"distribute_saved_activations"
)
# Copy the rng states.
self.fwd_cpu_rng_state = torch.get_rng_state()
self.fwd_cuda_rng_state = torch.cuda.get_rng_state()
self.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
outputs = CheckpointFunctionWithoutOutput.apply(run_function, self, *args)
self.outputs = outputs
if isinstance(self.outputs, torch.Tensor):
self.outputs = (self.outputs,)
return outputs
def discard_output(self):
for output in self.outputs:
output.untyped_storage().resize_(0)
def recompute(self, _):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
# Store the current states.
cur_cpu_rng_state = torch.get_rng_state()
cur_cuda_rng_state = torch.cuda.get_rng_state()
cur_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what they were before the forward pass.
torch.set_rng_state(self.fwd_cpu_rng_state)
_set_cuda_rng_state(self.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(self.fwd_cuda_rng_state_tracker)
with torch.enable_grad():
outputs = self.run_function(*self.ctx.saved_tensors)
self.run_function = None
self.fwd_cpu_rng_state = None
self.fwd_cuda_rng_state = None
self.fwd_cuda_rng_state_tracker = None
# Set the states back to what they were at the start of this function.
torch.set_rng_state(cur_cpu_rng_state)
_set_cuda_rng_state(cur_cuda_rng_state)
get_cuda_rng_tracker().set_states(cur_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
for output, recomputation_output in zip(self.outputs, outputs):
output_size = recomputation_output.untyped_storage().size()
output.untyped_storage().resize_(output_size)
with torch.no_grad():
output.untyped_storage().copy_(recomputation_output.untyped_storage())
self.ctx.outputs = outputs
self.outputs = None
self.ctx = None
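A minimal usage sketch of CheckpointWithoutOutput, mirroring how the MTP module below uses it. It assumes CUDA is available and the Megatron RNG tracker has been seeded; the layer and tensor names are illustrative:

import torch
from megatron.core.tensor_parallel.random import CheckpointWithoutOutput

norm = torch.nn.LayerNorm(16).cuda()
x = torch.randn(4, 16, device='cuda', requires_grad=True)

ckpt = CheckpointWithoutOutput()
y = ckpt.checkpoint(norm, False, x)   # run norm(x) under no_grad, remember only the inputs
out = y * 2.0                         # downstream op that consumes y
ckpt.discard_output()                 # free y's storage until backward needs it
out.register_hook(ckpt.recompute)     # refill y right before gradients reach `out`
out.sum().backward()                  # recompute fires first, then backward proceeds as usual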
......@@ -6,6 +6,18 @@ from typing import Optional
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
try:
from megatron.core.extensions.transformer_engine import (
fused_permute,
fused_sort_chunks_by_index,
fused_unpermute,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
def switch_load_balancing_loss_func(
......@@ -59,7 +71,6 @@ def switch_load_balancing_loss_func(
def sequence_load_balancing_loss_func(
probs: torch.Tensor,
routing_map: torch.Tensor,
tokens_per_expert: torch.Tensor,
batch_size: int,
seq_length: int,
topk: int,
......@@ -70,25 +81,40 @@ def sequence_load_balancing_loss_func(
Calculate the auxiliary loss at the sequence level by computing the loss for each individual sample.
Refer to the DeepSeek-V2 huggingface repo
(https://huggingface.co/deepseek-ai/DeepSeek-V2) for details.
Args:
probs (torch.Tensor): Softmax probabilities output by the router for each token.
Shape in [num_tokens, num_experts].
routing_map (torch.Tensor): Mapping of tokens to experts assignment.
Shape in [num_tokens, num_experts].
batch_size (int): Batch size to process.
seq_length (int): Sequence length to process.
topk (int): Number of experts to route to for each token.
moe_aux_loss_coeff (float): Scaling coefficient for the auxiliary loss.
sequence_partition_group (optional): The parallel group over which the sequence is
partitioned. If None, no partitioning is applied.
Defaults to None.
Returns:
torch.Tensor: The sequence auxiliary loss for load balancing.
"""
num_sub_sequence = 1
num_experts = probs.shape[1]
probs_for_aux_loss = probs.view(seq_length, batch_size, -1)
routing_map = routing_map.view(seq_length, batch_size, -1)
# If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism
# or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full
# sequence.
if sequence_partition_group is not None:
# We can keep `aggregated_probs_per_expert` local since we don't need the gradient for
# `tokens_per_expert`, saving one allreduce operation for `aggregated_probs_per_expert`.
num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group)
torch.distributed.all_reduce(tokens_per_expert, group=sequence_partition_group)
assert num_sub_sequence == 1, "Do not support sequence aux loss in sequence partition case"
num_experts = probs.shape[1]
seq_length *= num_sub_sequence
probs_for_aux_loss = gather_from_sequence_parallel_region(
probs_for_aux_loss, group=sequence_partition_group
)
probs_for_aux_loss = probs.view(seq_length, batch_size, -1)
cost_coeff = routing_map.view(seq_length, batch_size, -1).sum(dim=0).float()
cost_coeff.div_(seq_length * topk / num_experts)
cost_coeff = routing_map.sum(dim=0, dtype=torch.float).div_(seq_length * topk / num_experts)
seq_aux_loss = (cost_coeff * probs_for_aux_loss.mean(dim=0)).sum(dim=1).mean()
seq_aux_loss *= moe_aux_loss_coeff
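A self-contained sketch of the sequence-level aux-loss math above, with no sequence partitioning, toy sizes, and random router outputs:

import torch

seq_length, batch_size, num_experts, topk = 4, 1, 4, 2
moe_aux_loss_coeff = 1e-2

probs = torch.softmax(torch.randn(seq_length * batch_size, num_experts), dim=-1)
top_idx = probs.topk(topk, dim=-1).indices
routing_map = torch.zeros_like(probs, dtype=torch.bool)
routing_map.scatter_(1, top_idx, torch.ones_like(top_idx, dtype=torch.bool))

probs_sb = probs.view(seq_length, batch_size, -1)
routing_sb = routing_map.view(seq_length, batch_size, -1)

# Per-sample fraction of tokens routed to each expert, normalized so that a
# perfectly uniform assignment yields a coefficient of 1 for every expert.
cost_coeff = routing_sb.sum(dim=0, dtype=torch.float) / (seq_length * topk / num_experts)
seq_aux_loss = (cost_coeff * probs_sb.mean(dim=0)).sum(dim=1).mean() * moe_aux_loss_coeff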
......@@ -192,29 +218,60 @@ class MoEAuxLossAutoScaler(torch.autograd.Function):
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
def permute(tokens, routing_map, num_out_tokens: int = None):
def permute(
tokens,
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of the mask is [tokens, num_experts]; it indicates which experts were selected
by each token.
When drop_and_pad=True, the number of non-zeros in each column of routing_map equals the
expert capacity. This function exploits this feature to use ops that support cuda graph.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens.
fused (bool, optional): Whether to use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
"""
if fused:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens)
num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
# use argsort to put indices of all non-zeros in the beginning of list
# and keep the first `capacity` number of indices
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
# flatten from [num_experts, capacity] to 1D
sorted_indices = sorted_indices.view(-1)
else:
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
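A toy walkthrough of the non-fused permute path above, with 3 tokens, 2 experts, and top-1 routing (values are illustrative):

import torch

tokens = torch.tensor([[0.], [1.], [2.]])                 # [num_tokens, hidden=1]
routing_map = torch.tensor([[1, 0],                       # token 0 -> expert 0
                            [0, 1],                       # token 1 -> expert 1
                            [1, 0]], dtype=torch.bool)    # token 2 -> expert 0

expert_first = routing_map.T.contiguous()                 # [num_experts, num_tokens]
token_indices = torch.arange(3).unsqueeze(0).expand(2, -1)
sorted_indices = token_indices.masked_select(expert_first)   # tensor([0, 2, 1])
permuted = tokens.index_select(0, sorted_indices)
# permuted rows: tokens 0 and 2 (expert 0's tokens) followed by token 1 (expert 1's)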
......@@ -228,11 +285,19 @@ def unpermute(
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
will also apply them to the tokens before restoring the order.
When drop_and_pad=True, the tensors will have the following properties:
- In routing_map, the number of non-zeros in each column equals the expert capacity
- The size of sorted_indices equals num_experts * capacity; each split of `capacity`
contains the indices of tokens routed to an expert.
This function exploits these features to use ops that support cuda graph.
Args:
permuted_tokens (torch.Tensor): The permuted token tensor.
sorted_indices (torch.Tensor): The indices used to sort the tokens.
......@@ -240,15 +305,40 @@ def unpermute(
probs (torch.Tensor, optional): The unpermuted probs tensor,
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
fused (bool, optional): Whether to use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns:
torch.Tensor: The tokens restored to their original order.
"""
if fused:
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape)
_, hidden = restore_shape
if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
probs_T_1D = probs.T.contiguous().view(-1)
# get 1D indices of the probs selected by routing_map
indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
# get probs from indices
permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
# Create an output tensor filled with zeros
......@@ -260,54 +350,72 @@ def unpermute(
return output_tokens
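A small check of the drop_and_pad probability gather above, showing that the flattened-index trick picks probs[token, expert] for each permuted slot (toy sizes):

import torch

# 4 tokens, 2 experts, capacity 2: sorted_indices holds expert 0's tokens, then expert 1's.
probs = torch.tensor([[0.9, 0.1],
                      [0.2, 0.8],
                      [0.7, 0.3],
                      [0.4, 0.6]])
sorted_indices = torch.tensor([0, 2, 1, 3])     # expert 0 -> tokens {0, 2}, expert 1 -> {1, 3}

num_experts, num_tokens, capacity = 2, 4, 2
probs_T_1D = probs.T.contiguous().view(-1)      # [p(e0, t0..t3), p(e1, t0..t3)]
idx = (torch.arange(num_experts).unsqueeze(-1) * num_tokens
       + sorted_indices.view(num_experts, capacity)).view(-1)
permuted_probs = probs_T_1D.index_select(0, idx)
# tensor([0.9000, 0.7000, 0.8000, 0.6000]) == probs[token, expert] per permuted slot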
def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):
def sort_chunks_by_idxs(
input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False
):
"""Split and sort the input tensor based on the split_sizes and sorted indices."""
if fused:
if not HAVE_TE or fused_sort_chunks_by_index is None:
raise ValueError(
"fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
)
return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs)
input = torch.split(input, split_sizes.tolist(), dim=0)
output = torch.cat([input[i] for i in sorted_idxs], dim=0)
output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0)
return output
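A toy run of the non-fused path (chunk sizes and ordering are illustrative):

import torch

x = torch.arange(6).unsqueeze(-1)                 # rows 0..5
split_sizes = torch.tensor([2, 1, 3])             # three chunks of sizes 2, 1, 3
sorted_idxs = torch.tensor([2, 0, 1])             # put the last chunk first
chunks = torch.split(x, split_sizes.tolist(), dim=0)
out = torch.cat([chunks[i] for i in sorted_idxs.tolist()], dim=0)
# rows: [3, 4, 5, 0, 1, 2]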
def device_limited_topk(
def group_limited_topk(
scores: torch.Tensor,
topk: int,
num_tokens: int,
num_experts: int,
moe_router_topk_limited_devices: int,
num_groups: int,
group_topk: int,
):
"""Perform top-k routing on a subset of expert parallel ranks.
"""Perform top-k routing on a subset of expert groups.
Selects N ranks for each token, then conducts top-k selection among experts on these devices.
See DeepSeek-V2 technical report (https://arxiv.org/pdf/2405.04434) for details.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
(specifically, the sum of top-2 expert scores within each group)
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
Args:
scores (torch.Tensor): Softmax scores from the router.
scores (torch.Tensor): Softmax scores generated by the router.
topk (int): The number of experts to select for each token.
num_tokens (int): The number of tokens.
num_experts (int): The number of experts.
moe_router_topk_limited_devices (int): Number of expert parallel ranks to consider for
each token during routing. None means no device limitation.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of groups selected for each token.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor.
"""
# Organize the experts into groups
num_group = (
parallel_state.get_expert_model_parallel_world_size()
) # num_group equals to expert parallel size
group_scores = scores.view(num_tokens, num_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=moe_router_topk_limited_devices, dim=-1, sorted=False)[1]
group_scores = scores.view(num_tokens, num_groups, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
# Mask the experts based on selection groups
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_tokens, num_group, num_experts // num_group)
.expand(num_tokens, num_groups, num_experts // num_groups)
.reshape(num_tokens, -1)
)
masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf'))
probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
return probs, top_indices
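A worked example of the group masking above: 1 token, 4 experts split into 2 groups of 2, group_topk=1, topk=2 (values are illustrative):

import torch

scores = torch.tensor([[0.1, 0.5, 0.3, 0.2]])
num_tokens, num_experts, num_groups, group_topk, topk = 1, 4, 2, 1, 2

# Group scores are the sum of each group's top-2 experts:
# group 0: 0.1 + 0.5 = 0.6, group 1: 0.3 + 0.2 = 0.5 -> group 0 is selected.
group_scores = scores.view(num_tokens, num_groups, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = (group_mask.unsqueeze(-1)
              .expand(num_tokens, num_groups, num_experts // num_groups)
              .reshape(num_tokens, -1))
masked = scores.masked_fill(~score_mask.bool(), float('-inf'))
probs, top_indices = torch.topk(masked, k=topk, dim=-1)
# probs = [[0.5, 0.1]], top_indices = [[1, 0]] -- both chosen experts live in group 0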
......@@ -320,26 +428,31 @@ def topk_softmax_with_capacity(
pad_to_capacity: bool = False,
drop_policy: str = "probs",
use_pre_softmax: bool = False,
moe_router_topk_limited_devices: int = None,
moe_router_topk_scaling_factor: float = None,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
deterministic_mode: bool = False,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (int): The capacity factor of each expert. Will drop tokens if the number
capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number
of tokens exceeds the capacity.
pad_to_capacity (bool): Whether to need padding in token drop mode.
drop_policy (str): The policy to drop tokens. Can be either "probs" or "position".
If "prob", the tokens with the lowest probabilities will be dropped.
If "position", tokens at the end of each batch will be dropped.
use_pre_softmax (bool): Whether to apply softmax before top-k selection.
moe_router_topk_limited_devices (int): Number of expert parallel ranks to consider for
each token during routing. None means no device limitation.
moe_router_topk_scaling_factor (float): Scaling factor for routing score in top-k
selection, only works when use_pre_softmax enabled.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of selected groups for each token.
scaling_factor (float): Scaling factor of routing score in top-k selection.
deterministic_mode (bool): Deprecated.
score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
expert_bias (torch.Tensor): The bias added to logits for expert routing.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
......@@ -351,38 +464,42 @@ def topk_softmax_with_capacity(
the number of local tokens assigned to each expert before dropping and padding.
"""
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens = logits.shape[0]
num_experts = logits.shape[1]
if use_pre_softmax:
# Pre softmax
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
if moe_router_topk_limited_devices:
probs, top_indices = device_limited_topk(
scores, topk, num_tokens, num_experts, moe_router_topk_limited_devices
num_tokens, num_experts = logits.shape
def compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
topk=topk,
num_tokens=num_tokens,
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
)
else:
probs, top_indices = torch.topk(scores, k=topk, dim=1)
return torch.topk(scores, k=topk, dim=1)
# Normalize the probs.
if moe_router_topk_scaling_factor:
probs = probs * moe_router_topk_scaling_factor
else:
# Post softmax
if topk == 1:
# Requires applying softmax before selecting the top-k when k is 1,
# since softmax on a [num_tokens, 1] would yield a zero gradient.
raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")
assert (
moe_router_topk_scaling_factor is None
), "moe_router_topk_scaling_factor is not supported with post-softmax"
if moe_router_topk_limited_devices:
scores, top_indices = device_limited_topk(
logits, topk, num_tokens, num_experts, moe_router_topk_limited_devices
)
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = torch.topk(logits, k=topk, dim=1)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
if scaling_factor:
probs = probs * scaling_factor
# TODO Try using element-wise operations instead of scatter?
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
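A small sketch of the sigmoid-plus-bias branch above, showing that expert_bias only steers which experts are selected while the combine weights still come from the unbiased scores (toy values):

import torch

logits = torch.tensor([[0.0, 1.0, 2.0]])
expert_bias = torch.tensor([1.0, 0.0, 0.0])     # pushes traffic toward expert 0
topk = 2

scores = torch.sigmoid(logits)                  # [[0.500, 0.731, 0.881]]
_, top_indices = torch.topk(scores + expert_bias, k=topk, dim=1)   # experts [0, 2]
selected = torch.gather(scores, dim=1, index=top_indices)          # [[0.500, 0.881]]
probs = selected / (selected.sum(dim=-1, keepdim=True) + 1e-20)    # ~[[0.362, 0.638]]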
......@@ -516,3 +633,23 @@ def track_moe_metrics(
)
clear_aux_losses_tracker()
def get_updated_expert_bias(tokens_per_expert, expert_bias, expert_bias_update_rate):
"""Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#
Args:
tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
expert_bias (torch.Tensor): The bias for each expert.
expert_bias_update_rate (float): The update rate for the expert bias.
"""
with torch.no_grad():
# All Reduce Across TPxCPxDP group
torch.distributed.all_reduce(
tokens_per_expert,
group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True),
)
average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
offset = average_tokens - tokens_per_expert
updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate
return updated_expert_bias
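A single-rank sketch of the sign-based update rule (the real function first all-reduces tokens_per_expert across the TPxCPxDP group; numbers are illustrative):

import torch

tokens_per_expert = torch.tensor([30., 10., 20., 20.])
expert_bias = torch.zeros(4)
update_rate = 1e-3

average = tokens_per_expert.mean()                       # 20 tokens per expert
offset = average - tokens_per_expert                     # [-10, 10, 0, 0]
expert_bias = expert_bias + torch.sign(offset) * update_rate
# overloaded expert 0 gets a negative bias, underloaded expert 1 a positive one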
......@@ -102,8 +102,23 @@ class TopKRouter(Router):
super().__init__(config=config)
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.score_function = self.config.moe_router_score_function
self.input_jitter = None
self.enable_expert_bias = self.config.moe_router_enable_expert_bias
if self.enable_expert_bias:
self.register_buffer(
'local_tokens_per_expert',
torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
'expert_bias', torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
)
else:
self.local_tokens_per_expert = None
self.expert_bias = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
......@@ -154,9 +169,12 @@ class TopKRouter(Router):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
)
if self.training:
......@@ -183,9 +201,12 @@ class TopKRouter(Router):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
)
if self.training:
......@@ -194,7 +215,6 @@ class TopKRouter(Router):
sequence_load_balancing_loss_func,
probs=scores,
routing_map=routing_map,
tokens_per_expert=tokens_per_expert,
batch_size=bsz,
seq_length=seq_length,
topk=self.topk,
......@@ -210,11 +230,13 @@ class TopKRouter(Router):
):
"""Calculate auxiliary loss, attach gradient function to activation and add to logging."""
moe_aux_loss_coeff = self.config.moe_aux_loss_coeff
if moe_aux_loss_coeff == 0:
return activation
sequence_partition_group = None
if self.config.moe_token_dispatcher_type == "alltoall_seq":
sequence_partition_group = parallel_state.get_context_parallel_group()
moe_aux_loss_coeff /= parallel_state.get_tensor_model_parallel_world_size()
else:
elif parallel_state.get_tensor_and_context_parallel_world_size() > 1:
sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()
aux_loss = load_balancing_loss_func(
......@@ -309,12 +331,21 @@ class TopKRouter(Router):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
# Prevent extra local tokens accumulation on evaluation or activation recomputation
if self.enable_expert_bias and torch.is_grad_enabled():
with torch.no_grad():
self.local_tokens_per_expert += routing_map.sum(dim=0)
return scores, routing_map
def forward(self, input: torch.Tensor):
......
......@@ -308,11 +308,11 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
input_chunk_idxs = torch.arange(self.num_experts * self.tp_size)
# [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = (
input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist()
input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel()
)
# [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = (
input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist()
input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel()
)
# Token drop and padding.
......
import warnings
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredicationSubmodules, \
MultiTokenPredictor
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TENorm
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
import apex
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm
def get_mtp_spec(transformer_layer, use_te=False):
"""
Multi Token Prediction Layer Specification.
"""
use_te = use_te and HAVE_TE
mtp_spec = ModuleSpec(
module=MultiTokenPredictor,
submodules=MultiTokenPredicationSubmodules(
embedding=None,
enorm=TENorm if use_te else LNImpl,
hnorm=TENorm if use_te else LNImpl,
eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
transformer_layer=transformer_layer,
final_layernorm=TENorm if use_te else LNImpl,
output_layer=None,
)
)
return mtp_spec
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal
import torch
from torch import Tensor
from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module
from megatron.core.tensor_parallel.random import CheckpointWithoutOutput
@dataclass
class MultiTokenPredicationSubmodules:
embedding: Union[ModuleSpec, type] = None
output_layer: Union[ModuleSpec, type] = None
eh_proj: Union[ModuleSpec, type] = None
enorm: Union[ModuleSpec, type] = None
hnorm: Union[ModuleSpec, type] = None
transformer_layer: Union[ModuleSpec, type] = None
final_layernorm: Union[ModuleSpec, type] = None
class MultiTokenPredictor(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MultiTokenPredicationSubmodules,
vocab_size: int,
max_sequence_length: int,
layer_number: int = 1,
hidden_dropout: float = None,
pre_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
rotary_percent: float = 1.0,
rotary_base: int = 10000,
seq_len_interpolation_factor: Optional[float] = None,
share_mtp_embedding_and_output_weight=True,
recompute_mtp_norm=False,
recompute_mtp_layer=False,
add_output_layer_bias=False
):
super().__init__(config=config)
self.config = config
self.submodules = submodules
self.layer_number = layer_number
self.hidden_dropout = hidden_dropout
self.hidden_size = self.config.hidden_size
self.vocab_size = vocab_size
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.parallel_output = parallel_output
self.position_embedding_type = position_embedding_type
# share with main model
self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
self.recompute_layer_norm = recompute_mtp_norm
self.recompute_mtp_layer = recompute_mtp_layer
self.add_output_layer_bias = add_output_layer_bias
self.embedding = LanguageModelEmbedding(
config=self.config,
vocab_size=self.vocab_size,
max_sequence_length=self.max_sequence_length,
position_embedding_type=self.position_embedding_type,
skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
)
if self.position_embedding_type == 'rope':
self.rotary_pos_emb = RotaryEmbedding(
kv_channels=self.config.kv_channels,
rotary_percent=rotary_percent,
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
use_cpu_initialization=self.config.use_cpu_initialization,
)
self.enorm = build_module(
self.submodules.enorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.hnorm = build_module(
self.submodules.hnorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.eh_proj = build_module(
self.submodules.eh_proj,
self.hidden_size + self.hidden_size,
self.hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='eh',
)
self.transformer_layer = build_module(
self.submodules.transformer_layer,
config=self.config,
)
if self.submodules.final_layernorm:
self.final_layernorm = build_module(
self.submodules.final_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer = []
self.grad_output_buffer = []
else:
self.embedding_activation_buffer = None
self.grad_output_buffer = None
self.output_layer = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
self.vocab_size,
config=config,
init_method=config.init_method,
bias=self.add_output_layer_bias,
skip_bias_add=False,
gather_output=not self.parallel_output,
skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
embedding_activation_buffer=self.embedding_activation_buffer,
grad_output_buffer=self.grad_output_buffer,
)
def forward(
self,
hidden_input_ids: Tensor,
embed_input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
embedding_weight: Optional[torch.Tensor] = None,
output_weight: Optional[torch.Tensor] = None,
):
"""Forward function of the MTP module"""
# Decoder embedding.
decoder_input = self.embedding(
input_ids=embed_input_ids,
position_ids=position_ids,
weight=embedding_weight,
)
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
if self.position_embedding_type == 'rope':
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
rotary_seq_len = decoder_input.size(0)
if self.config.sequence_parallel:
rotary_seq_len *= self.config.tensor_model_parallel_size
rotary_seq_len *= self.config.context_parallel_size
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
if self.recompute_layer_norm:
self.enorm_ckpt = CheckpointWithoutOutput()
enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
self.hnorm_ckpt = CheckpointWithoutOutput()
hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
else:
enorm_output = self.enorm(decoder_input)
hnorm_output = self.hnorm(hidden_input_ids)
# [s, b, h] -> [s, b, 2h]
hidden_states = torch.concat(
[hnorm_output,
enorm_output],
dim=-1
)
if self.recompute_layer_norm:
self.enorm_ckpt.discard_output()
self.hnorm_ckpt.discard_output()
hidden_states.register_hook(self.enorm_ckpt.recompute)
hidden_states.register_hook(self.hnorm_ckpt.recompute)
# hidden_states -> [s, b, h]
hidden_states, _ = self.eh_proj(hidden_states)
if self.config.tensor_model_parallel_size > 1:
hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
if self.config.sequence_parallel:
hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)
if self.recompute_mtp_layer:
hidden_states, context = tensor_parallel.checkpoint(
self.transformer_layer,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
None,
None,
rotary_pos_emb,
inference_params,
packed_seq_params,
)
else:
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
**(extra_block_kwargs or {}),
)
# Final layer norm.
if self.final_layernorm is not None:
if self.recompute_layer_norm:
self.finalnorm_ckpt = CheckpointWithoutOutput()
finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
else:
finalnorm_output = self.final_layernorm(hidden_states)
else:
finalnorm_output = hidden_states
logits, _ = self.output_layer(finalnorm_output, weight=output_weight)
if self.recompute_layer_norm:
self.finalnorm_ckpt.discard_output()
logits.register_hook(self.finalnorm_ckpt.recompute)
if labels is None:
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
loss = self.compute_language_model_loss(labels, logits)
return hidden_states, loss
def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)
Args:
labels (Tensor): The labels of dimension [batch size, seq length]
logits (Tensor): The final logits returned by the output layer of the transformer model
Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length]
"""
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
loss = fused_vocab_parallel_cross_entropy(logits, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
\ No newline at end of file
......@@ -245,7 +245,9 @@ class TransformerBlock(MegatronModule):
# @TODO: add back standalone_embedding_stage (see issue #293)
# In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
# self.post_process and self.post_layer_norm guide this behavior
if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
# MTP requires separate layernorms for the main model and the MTP modules, so move the final norm out of the block
move_final_norm_out_of_block = args.num_nextn_predict_layers > 0
if self.submodules.layer_norm and self.post_process and self.post_layer_norm and not move_final_norm_out_of_block:
self.final_layernorm = build_module(
self.submodules.layer_norm,
config=self.config,
......
......@@ -45,14 +45,14 @@ class TransformerConfig(ModelParallelConfig):
If attention backend is local we use the local pytorch implementation in mcore.
Users can specify exact backend by changing this config. """
num_query_groups: int = None
num_query_groups: Optional[int] = None
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size: int = None
ffn_hidden_size: Optional[int] = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size
if not provided."""
kv_channels: int = None
kv_channels: Optional[int] = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
......@@ -93,7 +93,7 @@ class TransformerConfig(ModelParallelConfig):
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is casted back to the original precision before backprop compuatation."""
num_moe_experts: int = None
num_moe_experts: Optional[int] = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
......@@ -105,7 +105,7 @@ class TransformerConfig(ModelParallelConfig):
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""
normalization: bool = "LayerNorm"
normalization: str = "LayerNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
......@@ -124,13 +124,13 @@ class TransformerConfig(ModelParallelConfig):
####################
# initialization
####################
init_method: Callable = None
init_method: Optional[Callable] = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Callable = None
output_layer_init_method: Optional[Callable] = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
......@@ -176,7 +176,7 @@ class TransformerConfig(ModelParallelConfig):
####################
# activation recomputation
####################
recompute_granularity: str = None
recompute_granularity: Optional[str] = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
......@@ -186,7 +186,7 @@ class TransformerConfig(ModelParallelConfig):
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: str = None
recompute_method: Optional[str] = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
......@@ -194,19 +194,19 @@ class TransformerConfig(ModelParallelConfig):
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: int = None
recompute_num_layers: Optional[int] = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations: bool = None
distribute_saved_activations: Optional[bool] = None
"""If True, distribute recomputed activations across the model parallel group."""
####################
# fp8 related
####################
fp8: str = None
fp8: Optional[str] = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
......@@ -245,7 +245,7 @@ class TransformerConfig(ModelParallelConfig):
####################
# MoE related
####################
moe_shared_expert_intermediate_size: int = None
moe_shared_expert_intermediate_size: Optional[int] = None
"""Shared expert total ffn hidden size.
It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if
there are multiple shared experts.
......@@ -255,14 +255,12 @@ class TransformerConfig(ModelParallelConfig):
"""Enable overlapping between shared expert computations and dispatcher communications.
Without this, the shared experts execute after the routed experts."""
moe_layer_freq: int = 1
moe_layer_freq: Union[int, List[int]] = 1
"""Frequency between MoE layers and Dense layers. Accepts either:
- An integer N: Represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
- A string containing a Python list expression that defines a custom pattern, e.g.:
"([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0]
where 1 indicates an expert layer and 0 indicates a dense layer."""
- A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""
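As an illustration only, one way such a spec could be expanded into a per-layer MoE/Dense pattern (the helper name is hypothetical and not part of this commit):

from typing import List, Union

def expand_moe_layer_freq(spec: Union[int, List[int]], num_layers: int) -> List[int]:
    # Integer N: one MoE layer for every N layers (1 means every layer is MoE).
    if isinstance(spec, int):
        return [1 if i % spec == 0 else 0 for i in range(num_layers)]
    # Explicit pattern: 1 marks an MoE layer, 0 a dense layer.
    assert len(spec) == num_layers, "pattern length must match num_layers"
    return list(spec)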
moe_ffn_hidden_size: int = None
moe_ffn_hidden_size: Optional[int] = None
"""MoE Feed-Forward Network hidden size"""
moe_router_load_balancing_type: str = "aux_loss"
......@@ -274,19 +272,52 @@ class TransformerConfig(ModelParallelConfig):
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_router_topk_limited_devices: int = None
"""Number of expert parallel ranks to consider for each token during routing. Perform top-k
routing on a subset of expert parallel ranks by first selecting N ranks for each token, then
conducting top-k selection among experts on these devices. None means no device limitation."""
moe_router_topk_limited_devices: Optional[int] = None
"""Number of EP ranks to consider for each token in group-limited routing,
DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
"""
moe_router_num_groups: Optional[int] = None
"""Number of groups to divide experts into for group-limited routing.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
(specifically, the sum of top-2 expert scores within each group)
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
"""
moe_router_group_topk: Optional[int] = None
"""Number of selected groups for group-limited routing."""
moe_router_pre_softmax: bool = False
"""Enable pre-softmax routing for MoE, which means softmax is before the top-k selection.
By default, softmax is done after top-k."""
moe_router_topk_scaling_factor: float = None
moe_router_topk_scaling_factor: Optional[float] = None
"""Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
enabled. Defaults to None, which means no scaling."""
moe_router_score_function: str = "softmax"
"""Score function for MoE routing. Can be "softmax" or "sigmoid"."""
moe_router_enable_expert_bias: bool = False
"""TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy.
The routing decision is based on the sum of the routing scores and the expert bias.
See https://arxiv.org/abs/2408.15664 for details."""
moe_router_bias_update_rate: float = 1e-3
"""The expert bias is updated based on the number of assigned tokens to each expert
in a global batch, where the bias is increased for the experts with fewer assigned tokens
and decreased for the experts with more assigned tokens.
The default value 1e-3 is the same as that used in DeepSeekV3."""
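A minimal sketch of that update rule, assuming a simple sign-based step as in the DeepSeek-V3 description (this is not the helper used by this commit):

import torch

def sketch_update_expert_bias(tokens_per_expert: torch.Tensor,
                              expert_bias: torch.Tensor,
                              update_rate: float = 1e-3) -> torch.Tensor:
    # Raise the bias of under-loaded experts and lower it for over-loaded ones,
    # nudging subsequent routing decisions toward a balanced assignment.
    load_error = tokens_per_expert.float().mean() - tokens_per_expert.float()
    return expert_bias + update_rate * torch.sign(load_error)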
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
......@@ -300,13 +331,13 @@ class TransformerConfig(ModelParallelConfig):
moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
moe_z_loss_coeff: float = None # 1e-3 would be a good start value for z-loss
moe_z_loss_coeff: Optional[float] = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: float = None
moe_input_jitter_eps: Optional[float] = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False # TODO: Support token dropping.
moe_token_dropping: bool = False
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
......@@ -318,7 +349,7 @@ class TransformerConfig(ModelParallelConfig):
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
moe_expert_capacity_factor: float = None
moe_expert_capacity_factor: Optional[float] = None
"""moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token
will be dropped. The default is None."""
......@@ -339,7 +370,7 @@ class TransformerConfig(ModelParallelConfig):
##################
# Context Parallel
##################
cp_comm_type: Union[str, List[str]] = None
cp_comm_type: Optional[Union[str, List[str]]] = None
"""Inter-gpu communication type for context parallelism.
str: all layers share same communication type.
List[str]: each layer has its separate communication type.
......@@ -577,6 +608,12 @@ class TransformerConfig(ModelParallelConfig):
"alltoall_seq dispatcher not support different TP size for MoE and Dense layer."
)
if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
raise ValueError(
"Expert bias for aux-loss-free routing only supports sigmoid score function."
"Please set --moe-router-score-function sigmoid for sigmoid score function."
)
if self.num_moe_experts and self.fp8:
# TE version below 1.7.0 will raise Error when handle zeros tokens for expert
if not is_te_min_version("1.7.0.dev0"):
......
......@@ -1451,3 +1451,33 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
batch[key] = val
return batch
def tensor_slide(
tensor: Optional[torch.Tensor],
num_slice: int,
dims: Union[int, List[int]] = -1,
step: int = 1,
return_first=False,
) -> List[Union[torch.Tensor, None]]:
"""通用滑动窗口函数,支持任意维度"""
if tensor is None:
# return `List[None]` to avoid NoneType Error
return [None] * (num_slice + 1)
if num_slice == 0:
return [tensor]
window_size = tensor.shape[-1] - num_slice
dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
# Slide a contiguous window along each requested dimension.
slices = []
for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
slice_obj = [slice(None)] * tensor.dim()
for dim in dims:
slice_obj[dim] = slice(i, i + window_size)
slices.append(tensor[tuple(slice_obj)])
if return_first:
return slices
return slices
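A quick usage sketch of the helper above (values chosen for illustration):

import torch

# A [batch=2, seq=10] tensor sliced with num_slice=2 yields 3 shifted windows of width 8,
# e.g. one view per prediction depth when preparing multi-token-prediction targets.
labels = torch.arange(20).reshape(2, 10)
views = tensor_slide(labels, num_slice=2)
assert len(views) == 3 and views[0].shape == (2, 8)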
......@@ -47,6 +47,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_vision_args(parser)
parser = _add_moe_args(parser)
parser = _add_mla_args(parser)
parser = _add_mtp_args(parser)
parser = _add_logging_args(parser)
parser = _add_straggler_detector_args(parser)
parser = _add_inference_args(parser)
......@@ -403,10 +404,12 @@ def validate_args(args, defaults={}):
dtype_map = {
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
}
args.main_grads_dtype = dtype_map[args.main_grads_dtype]
args.main_params_dtype = dtype_map[args.main_params_dtype]
args.exp_avg_dtype = dtype_map[args.exp_avg_dtype]
args.exp_avg_sq_dtype = dtype_map[args.exp_avg_sq_dtype]
map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]
args.main_grads_dtype = map_dtype(args.main_grads_dtype)
args.main_params_dtype = map_dtype(args.main_params_dtype)
args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
if args.fp8_param_gather:
assert args.use_distributed_optimizer, \
......@@ -536,7 +539,9 @@ def validate_args(args, defaults={}):
args.seq_length = args.encoder_seq_length
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
assert args.max_position_embeddings >= args.seq_length, \
f"max_position_embeddings ({args.max_position_embeddings}) must be greater than " \
f"or equal to seq_length ({args.seq_length})."
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
if args.lr is not None:
......@@ -2121,17 +2126,39 @@ def _add_moe_args(parser):
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
# Router arguments
group.add_argument('--moe-router-load-balancing-type', type=str,
choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'],
choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none', 'noaux_tc'],
default='aux_loss',
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
help='Determines the load balancing strategy for the router. '
'"aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; '
'"seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; '
'"sinkhorn" corresponds to the balancing algorithm used in S-BASE, '
'"none" implies no load balancing, '
'"noaux_tc" corresponds to no aux loss load balancing method in DeepSeekV3. '
'The default is "aux_loss".')
group.add_argument('--moe-router-score-function', type=str,
choices=['softmax', 'sigmoid'],
default='softmax',
help='Score function for MoE TopK routing. Can be "softmax" or "sigmoid".')
group.add_argument('--moe-router-topk', type=int, default=2,
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-router-pre-softmax', action='store_true',
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-router-topk-limited-devices', type=int, default=None,
help='Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. Default is None, which means no limited devices.')
group.add_argument('--moe-router-num-groups', type=int, default=None,
help='Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups are selected based on routing scores (sum of top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen.'
'Two common use cases: 1) Device-limited routing: Set equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) 2) Node-limited routing: Set equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)')
group.add_argument('--moe-router-group-topk', type=int, default=None,
help='Number of selected groups for group-limited routing.')
group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
group.add_argument('--moe-router-enable-expert-bias', action='store_true',
help='TopK routing with dynamic expert bias in the aux-loss-free load balancing strategy. '
'The routing decision is based on the sum of the routing scores and the expert bias. '
'See https://arxiv.org/abs/2408.15664 for details.')
group.add_argument('--moe-router-bias-update-rate', type=float, default=1e-3,
help='Expert bias update rate in the aux-loss-free load balancing strategy. '
'The expert bias is updated based on the number of assigned tokens to each expert in a global batch, '
'where the bias is increased for the experts with fewer assigned tokens and decreased for the experts with more assigned tokens. '
'The default value 1e-3 is the same as that used in DeepSeekV3.')
group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true',
help='Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.')
group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
......@@ -2160,6 +2187,8 @@ def _add_moe_args(parser):
group.add_argument('--moe-use-upcycling', action='store_true',
help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. '
'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.')
group.add_argument('--moe-permute-fusion', action='store_true',
help='Fuse token rearrangement ops during token dispatching.')
return parser
......@@ -2180,6 +2209,20 @@ def _add_mla_args(parser):
return parser
def _add_mtp_args(parser):
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Number of Multi-Token Prediction (MTP) layers.')
group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Loss scaling factor for the Multi-Token Prediction heads.')
group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
help='Recompute the layernorm activations inside the MTP modules.')
group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
help='Recompute the transformer-layer activations inside the MTP modules.')
group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
help='Share the embedding and output weights between the main model and the MTP layers.')
return parser
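For a quick sanity check of these flags, a hedged standalone usage sketch (assumes _add_mtp_args is in scope; not part of Megatron's argument pipeline):

import argparse

parser = argparse.ArgumentParser()
_add_mtp_args(parser)
args = parser.parse_args(['--num-nextn-predict-layers', '1',
                          '--share-mtp-embedding-and-output-weight'])
# Defaults: mtp_loss_scale=0.3, both recompute flags off.
assert args.num_nextn_predict_layers == 1
assert args.share_mtp_embedding_and_output_weight and args.mtp_loss_scale == 0.3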
def _add_experimental_args(parser):
group = parser.add_argument_group(title='experimental')
......
......@@ -8,6 +8,8 @@ from datetime import datetime
import torch
from typing import List, Optional, Union
try:
from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_l2norm
except ImportError:
......@@ -417,22 +419,36 @@ def get_batch_on_this_tp_rank(data_iterator):
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_last_stage():
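# With MTP enabled, the last pipeline stage also needs tokens and position_ids
# to build the shifted targets for the extra prediction heads.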
if args.num_nextn_predict_layers:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(batch['position_ids'])
else:
tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
tokens=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.float32,
device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
(args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
args.seq_length + args.num_nextn_predict_layers), dtype = torch.bool,
device = torch.cuda.current_device()
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
position_ids=torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
dtype = torch.int64,
device = torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
......@@ -450,13 +466,20 @@ def get_batch_on_this_tp_rank(data_iterator):
_broadcast(position_ids)
elif mpu.is_pipeline_last_stage():
tokens=None
position_ids=None
if args.num_nextn_predict_layers:
_broadcast(tokens)
else:
tokens = None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
if args.reset_position_ids or args.num_nextn_predict_layers:
_broadcast(position_ids)
else:
position_ids = None
batch = {
'tokens': tokens,
'labels': labels,
......@@ -470,3 +493,33 @@ def get_batch_on_this_tp_rank(data_iterator):
def update_use_dist_ckpt(args):
args.use_dist_ckpt = args.ckpt_format != "torch"
def tensor_slide(
tensor: Optional[torch.Tensor],
num_slice: int,
dims: Union[int, List[int]] = -1,
step: int = 1,
return_first=False,
) -> List[Union[torch.Tensor, None]]:
"""通用滑动窗口函数,支持任意维度"""
if tensor is None:
# return `List[None]` to avoid NoneType Error
return [None] * (num_slice + 1)
if num_slice == 0:
return [tensor]
window_size = tensor.shape[-1] - num_slice
dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
# Slide a contiguous window along each requested dimension.
slices = []
for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
slice_obj = [slice(None)] * tensor.dim()
for dim in dims:
slice_obj[dim] = slice(i, i + window_size)
slices.append(tensor[tuple(slice_obj)])
if return_first:
return slices
return slices
......@@ -35,6 +35,9 @@ from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
import torch._dynamo
torch._dynamo.config.suppress_errors = True
......@@ -112,6 +115,17 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
except:
raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
# Define the decoder layer spec
if use_te:
mtp_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
else:
mtp_transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
with build_model_context(**build_model_context_args):
model = GPTModel(
config=config,
......@@ -126,7 +140,13 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling
rope_scaling=args.use_rope_scaling,
mtp_spec=mtp_spec,
num_nextn_predict_layers=args.num_nextn_predict_layers,
share_mtp_embedding_and_output_weight=args.share_mtp_embedding_and_output_weight,
recompute_mtp_norm=args.recompute_mtp_norm,
recompute_mtp_layer=args.recompute_mtp_layer,
mtp_loss_scale=args.mtp_loss_scale
)
model = torch.compile(model, mode='max-autotune-no-cudagraphs')
print_rank_0(model)
......@@ -249,7 +269,7 @@ def core_gpt_dataset_config_from_args(args):
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length,
sequence_length=args.seq_length + args.num_nextn_predict_layers,
blend=blend,
blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
......