"""Permute the tokens based on the indices. Token with the same index will be grouped together.
The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately.
Args:
tokens (torch.Tensor): The input token tensor.
indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk].
num_out_tokens (int, optional): The effective number of output tokens. When the capacity factor is enabled, this should equal the number of tokens that are not dropped. Defaults to None, meaning no tokens are dropped.
padded_mode (bool, optional): If True, indicates that the indices are padded to [num_expert, capacity] to denote the selected tokens per expert. Defaults to False.
Returns:
torch.Tensor: The permuted tensor.
torch.Tensor: The sorted_indices corresponding to the permuted tensor.
"""
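# Illustrative sketch only (not the library implementation): one way the non-padded
# permute path could be realized with a plain argsort; `_permute_sketch` and its exact
# return layout are assumptions.
import torch


def _permute_sketch(tokens, indices, num_out_tokens=None):
    # Flatten the [num_tokens, topk] expert assignments and sort them so that rows
    # routed to the same expert become contiguous.
    topk = 1 if indices.dim() == 1 else indices.size(1)
    sorted_indices = torch.argsort(indices.view(-1))
    if num_out_tokens is not None:
        # Keep only the tokens that survive the capacity-factor drop.
        sorted_indices = sorted_indices[:num_out_tokens]
    # Each flattened position maps back to its source row via integer division by topk.
    permuted_tokens = tokens.index_select(0, sorted_indices // topk)
    return permuted_tokens, sorted_indices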
"""Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
Args:
permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted.
sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens.
probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities.
padded_mode (bool, optional): If True, indicates that the indices are padded to [num_expert, capacity] to denote the selected tokens per expert. Defaults to False.
restore_shape (torch.Size, optional): The input shape before permutation, only used in padding mode. Defaults to None.
Returns:
torch.Tensor: The unpermuted tokens, optionally merged with probabilities.
"""
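# Illustrative sketch only: a non-padded unpermute, assuming the index layout produced
# by an argsort-based permute as above (`_unpermute_sketch` is a hypothetical name).
import torch


def _unpermute_sketch(permuted_tokens, sorted_indices, probs=None):
    topk = probs.size(-1) if probs is not None else 1
    num_rows = probs.numel() if probs is not None else permuted_tokens.size(0)
    # Scatter each permuted row back to its original (token, expert-slot) position.
    out = torch.zeros(num_rows, permuted_tokens.size(-1),
                      dtype=permuted_tokens.dtype, device=permuted_tokens.device)
    out.index_copy_(0, sorted_indices, permuted_tokens)
    # Regroup by source token and, if routing probabilities are given, merge the top-k
    # expert outputs as a probability-weighted sum.
    out = out.reshape(-1, topk, permuted_tokens.size(-1))
    if probs is not None:
        out = out * probs.unsqueeze(-1)
    return out.sum(dim=1)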
"""Unpermute padded permuted tokens based on sorted indices and merge the tokens with their corresponding probabilities.
This function takes a tensor of permuted tokens and reorders them according to the provided indices. It also combines the tokens with their associated probabilities.
Parameters:
permuted_tokens (torch.Tensor): A 2D tensor containing permuted tokens.
indices (torch.Tensor): A tensor with shape [num_expert, capacity], indicating the selected tokens for each expert.
probs (torch.Tensor): A tensor with the same shape as indices, containing probabilities corresponding to each token.
restore_shape (torch.Size): The target shape for the unpermuted tokens tensor.
Returns:
torch.Tensor: A tensor of unpermuted tokens, merged with their probabilities.
"""
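# Illustrative sketch only: padded-mode unpermute, assuming indices and probs are laid
# out as [num_expert, capacity] (`_unpermute_padded_sketch` is a hypothetical name).
import torch


def _unpermute_padded_sketch(permuted_tokens, indices, probs, restore_shape):
    hidden = permuted_tokens.size(-1)
    # Weight each expert output by its routing probability before merging.
    weighted = permuted_tokens * probs.view(-1, 1)
    # Scatter-add into a zero tensor of the original [num_tokens, hidden] shape; the
    # top-k copies of a token accumulate back into a single row.
    out = torch.zeros(restore_shape, dtype=weighted.dtype, device=weighted.device)
    out.scatter_add_(0, indices.view(-1, 1).expand(-1, hidden), weighted)
    return out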
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (float): The capacity factor of each expert. Tokens will be dropped if the number of tokens routed to an expert exceeds its capacity.
pad_to_capacity (bool): Whether to pad the input of each expert to the capacity length in token drop mode.
drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". If "prob", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.
(1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
(2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
"""
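# Illustrative sketch only: deriving top-k routing outputs and the per-expert token
# count, assuming capacity = ceil(num_tokens * topk / num_experts * capacity_factor);
# the exact capacity formula and the name `_topk_with_capacity_sketch` are assumptions.
import math

import torch


def _topk_with_capacity_sketch(logits, topk, capacity_factor):
    num_tokens, num_experts = logits.shape
    capacity = math.ceil(num_tokens * topk / num_experts * capacity_factor)
    # Top-k routing probabilities and expert indices, shape [num_tokens, topk].
    scores = torch.softmax(logits, dim=-1)
    probs, indices = torch.topk(scores, k=topk, dim=-1)
    # Number of tokens routed to each expert before any dropping, shape [num_experts].
    tokens_per_expert = torch.bincount(indices.view(-1), minlength=num_experts)
    return probs, indices, tokens_per_expert, capacity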
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
# self.local_probs: probs of global token assignment to local experts.
self.local_probs = None
# self.indices: The indices of `local_indices` (which holds the un-sorted expert indices of tokens that local expert can process) that give its sorted order along dim 0.
self.indices = None
# self.global_local_map: 2D tensor. A mask of the mapping between global and local tokens where each element is True if it falls within the local_expert_indices. Only useful when cross-device token permutation is enabled and **AllGather** is performed.
self.global_local_map = None
"""Preprocess token indices for AlltoAll communication and token permutation.

This method computes the number of tokens assigned to each expert based on the input indices.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts.
Args:
indices (torch.Tensor): Tensor of indices mapping tokens to experts.
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to each local expert.
"""
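# Illustrative sketch only: counting tokens routed to this rank's experts, assuming
# `indices` holds global expert ids and `local_expert_indices` is a sorted list of the
# contiguous expert ids hosted on this rank (names are hypothetical).
import torch


def _tokens_per_local_expert_sketch(indices, local_expert_indices, num_experts):
    # Histogram of routed tokens over all experts, shape [num_experts].
    num_tokens_per_expert = torch.bincount(indices.view(-1), minlength=num_experts)
    # Slice out the experts hosted on this rank; in an AlltoAll dispatcher these counts
    # also seed the input/output split sizes for the communication step.
    first, last = local_expert_indices[0], local_expert_indices[-1]
    return num_tokens_per_expert[first:last + 1]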
"""If True, uses the original BERT residule connection ordering."""
layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""
layernorm_zero_centered_gamma: bool = False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear: bool = True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""
add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""
gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func: Callable = F.gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store: bool = False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is casted back to the original precision before backprop compuatation."""
num_moe_experts: int = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""
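# Hedged example (illustrative values): a causal sliding window over the previous 4096
# tokens could be requested as
#     window_size = (4096, 0)
# assuming the tuple is interpreted as (left, right) context sizes, with -1 leaving a
# side unbounded.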
# normalization: str = "LayerNorm"
normalization: str = "RMSNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
"""Whether to apply LayerNorm to the query and key embeddings."""
test_mode: bool = False
"""Whether to run real-time tests."""
calculate_per_token_loss: bool = False
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
####################
# initialization
####################
init_method: Callable = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Callable = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
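# Hedged sketch of the defaults described above (illustrative, not the library code):
#     init_method ~ lambda w: torch.nn.init.normal_(w, mean=0.0, std=init_method_std)
#     output_layer_init_method uses the same normal init but with
#         std = init_method_std / math.sqrt(2.0 * num_layers)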
####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If True, scale Q * K^T by 1 / layer-number. This improves numerical stability when training
with fp16."""
attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""
persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion: bool = False  # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""
####################
# activation recomputation
####################
recompute_granularity: str = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where only the memory intensive part of attention is checkpointed.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: str = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: int = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
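# Hedged usage example (illustrative values): full recomputation of two layers per
# pipeline stage could be requested with
#     recompute_granularity='full', recompute_method='block', recompute_num_layers=2
# whereas 'selective' granularity checkpoints only the memory-intensive attention parts
# and requires recompute_num_layers to stay None.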
distribute_saved_activations: bool = None
"""If True, distribute recomputed activations across the model parallel group."""
####################
# fp8 related
####################
fp8: str = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_margin: int = 0
"""Margin for the scaling factor computation."""
fp8_interval: int = 1
"""Controls how often the scaling factor is recomputed."""
fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo: str = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation in higher precision."""
fp8_dot_product_attention: bool = False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention: bool = False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
####################
# MoE related
####################
moe_router_load_balancing_type: str = "aux_loss"
"""Determines the load balancing strategy for the router. "aux_loss" corresponds to the load
balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing
algorithm used in S-BASE, and "none" implies no load balancing."""
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""
moe_aux_loss_coeff: float = 0  # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
moe_z_loss_coeff: float = None  # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: float = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False  # TODO: Support token dropping.
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
moe_token_dispatcher_type: str = "allgather"
"""The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather' and 'alltoall'."""
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
moe_expert_capacity_factor: float = None
"""The capacity factor for each expert. None means no token will be dropped. The default is None."""
moe_pad_expert_input_to_capacity: bool = False
"""If True, pads the input for each expert to match the expert capacity length. Effective only
after moe_expert_capacity_factor is set. The default setting is False."""
moe_token_drop_policy: str = 'probs'
"""The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.
"""
moe_layer_recompute: bool = False
"""Memory optimization: checkpointing moe_layer to save activation memory."""
####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""
disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""
enable_cuda_graph: bool = False
"""When set to true, TransformerLayer blocks are wrapped with CUDA graph."""
def __post_init__(self):
    """Python dataclass method that is used to modify attributes after initialization.
    See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
    """
    super().__post_init__()
    if self.fp16 and self.bf16:
        raise ValueError(
            f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'