Commit 2c63b5cd authored by wangxj's avatar wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
File mode changed from 100755 to 100644
......@@ -21,6 +21,7 @@ from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
# pylint: disable=missing-class-docstring
@dataclass
class MLPSubmodules:
linear_fc1: Union[ModuleSpec, type] = None
......@@ -129,6 +130,7 @@ class MLP(MegatronModule):
return output, output_bias
# pylint: disable=missing-function-docstring
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
......@@ -136,7 +138,9 @@ class MLP(MegatronModule):
for name, module in self._modules.items():
sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
if self.config.gated_linear_unit and name == 'linear_fc1':
assert f'{prefix}{name}.weight' in sub_sd, sub_sd.keys()
# NOTE: With custom FSDP, the local shard may contain no weight.
if not self.config.use_custom_fsdp:
assert f'{prefix}{name}.weight' in sub_sd, sub_sd.keys()
for k, v in sub_sd.items():
if k in (f'{prefix}{name}.weight', f'{prefix}{name}.bias'):
sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)
......@@ -144,6 +148,7 @@ class MLP(MegatronModule):
return sharded_state_dict
# pylint: disable=missing-function-docstring
def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
# We must split the tensor into 2 parts, each sharded separately.
# This requires a ShardedTensorFactory which `chunk`s during saving
......@@ -258,4 +263,5 @@ def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
sh_ten_build_fn,
sh_ten_merge_fn,
original_sh_ten.replica_id,
flattened_range=original_sh_ten.flattened_range,
)
File mode changed from 100755 to 100644
# Megatron Core MoE Key Features
# Megatron Core MoE
Megatron-Core offers rich parallelism mappings, combining Expert Parallelism with tensor, data, sequence, and pipeline parallelism. This boosts Mixtral 8X7B bf16 training to achieve **468 TFLOPS** as of MCore v0.9.
Megatron-Core MoE provides comprehensive parallelism strategies, seamlessly integrating Expert Parallelism with tensor, data, sequence, and pipeline parallelism. With MCore v0.9, we've achieved remarkable performance of **468 TFLOPS** for Mixtral 8X7B bf16 training. Additionally, we support state-of-the-art MoE model architectures including DeepSeek-V3 and Qwen-MoE.
### What's New
- **Support for DeepSeek-V3 architecture**
- Enable TP for MLA and DeepSeek-V3
- Support aux-loss-free load balancing strategy
- Support node-limited routing
- **Support DeepSeek's DeepEP for efficient token dispatching and combining**
- Add fusion for token permutation and unpermutation
- Support uneven virtual pipeline parallel split
### Parallelism
- **Expert Parallelism**
......@@ -11,6 +19,7 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
- **Context Parallelism**:
- Split the sequence dimension to support long context training.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/CP for handling larger MoE variants.
- **MoE Parallel Folding**: Support for setting different parallelism strategies for Attention and MoE components, enabling more flexible and efficient model sharding. See detailed documentation below.
- **Full distributed optimizer support.**
### Router and Load Balancing
......@@ -19,8 +28,10 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss
- Aux-loss-free load balancing strategy
### Performance Optimizations
- (Experimental) **DeepEP** is integrated for efficient token communication in large-scale MoE training.
- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
- Performance improvements for larger MoE models
......@@ -30,49 +41,21 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
### Token Dispatch Mechanism
- Dropless / No token drop
- Token drop, with or without padding to capacity
- Token permutation / Unpermutation fusion
### Ease of use
- Checkpoint converter for Mixtral models, see the [example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mixtral) for details.
- MoE Layer Frequency to customize the hybrid MoE/Dense layer architecture
- Distributed checkpointing
- Per-layer logging
- Upcycling Support
- Granular upcycling
## Upcoming features
- New Parallelism for Large-scale MoE training
- FP8 support for GroupedGEMM
- Token permutation / Unpermutation fusion
- TopK Router Fusion
- MoE Layer Frequency
- Multi-token Prediction
# User Guide
### MoE Related Arguments
| Item | Description |
| --- | --- |
| --num-experts | Number of Experts in MoE (None means no MoE) |
| --expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| --moe-ffn-hidden-size | MoE Feed-Forward Network hidden size. Default is None. |
| --expert-tensor-parallel-size | Degree of tensor model parallelism of the expert layer. Default is the same as --tensor-model-parallel-size. |
| --moe-layer-freq | Frequency between MoE layers and Dense layers. Accepts either: 1) An integer N for a 1:N ratio (one expert layer for every N-1 dense layers), 2) A string "N" for the same ratio, or 3) A string with a Python list expression for custom patterns like `([1]*3+[0]*1)*3` which gives [1,1,1,0,1,1,1,0,1,1,1,0] where 1=expert layer and 0=dense layer. Examples: `([0]+[1]*23)` for 1 dense layer followed by 23 expert layers, `([1]*3+[0]*2)*2` for three expert layers followed by two dense layers, repeated twice. Default is 1. |
| --moe-grouped-gemm | When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine. |
| --moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| --moe-router-topk | Number of experts to route to for each token. The default is 2. |
| --moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| --moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| --moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| --moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather", "alltoall" and "alltoall_seq". Default is "allgather". We recommend using 'alltoall' if expert parallelism is applied. We have upgraded the "alltoall" dispatcher in place during MCore v0.9, while retaining the original implementation, renamed as "alltoall_seq".|
| --moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| --moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
| --moe-shared-expert-intermediate-size | Set shared expert total ffn hidden size. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. |
| --moe-shared-expert-overlap | (Experimental, may change) If this is set, the communications/computations in the shared experts and the dispatcher will overlap (the `alltoall` dispatcher is required). Otherwise, the shared expert runs after the routed experts. |
| --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on top of distributed checkpointing, so it supports parallel modes different from the dense model. |
## Usage
### Quick Start
......@@ -82,6 +65,7 @@ To train a top-2 MoE model with 8 experts and auxiliary loss, include the follow
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-permute-fusion
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
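To make the expert placement concrete, here is a small sketch (the helper below is illustrative, not an MCore API) of how `--num-experts 8 --expert-model-parallel-size 8` distributes experts: each EP rank owns `num_experts // ep_size` local experts.
```python
# Illustrative helper (not part of Megatron-Core): compute which experts live on a given EP rank.
def local_expert_indices(num_experts: int, ep_size: int, ep_rank: int) -> list:
    assert num_experts % ep_size == 0, "--num-experts must be divisible by --expert-model-parallel-size"
    num_local_experts = num_experts // ep_size
    offset = ep_rank * num_local_experts
    return list(range(offset, offset + num_local_experts))

print(local_expert_indices(num_experts=8, ep_size=8, ep_rank=3))  # [3] -- one expert per rank
print(local_expert_indices(num_experts=8, ep_size=2, ep_rank=1))  # [4, 5, 6, 7]
```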
......@@ -172,6 +156,55 @@ We currently only support the default upcycling strategy, which duplicates the e
Note: The MoE model structure is defined through script arguments. All MoE-related arguments (such as `--num-experts`) can be customized; however, other model structure arguments must be consistent with those of the dense model.
### Leverage DeepSeek's DeepEP for High-Performance Cross-Node Token Dispatching
- [DeepSeek-DeepEP](https://github.com/deepseek-ai/deepep) provides a highly optimized implementation for MoE token dispatching and combining operations, specifically designed for large-scale MoE training scenarios.
- DeepEP is particularly recommended for training large-scale, fine-grained MoE architectures such as DeepSeek-V3 and other advanced MoE models.
- To enable DeepEP in your training configuration, simply set `--moe-token-dispatcher-type=flex` and `--moe-enable-deepep` in your command line arguments.
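The two flags must be used together; the snippet below is a minimal, hypothetical validation helper (not an MCore API) that reflects the pairing described above and in the dispatcher notes further down.
```python
# Hypothetical helper: DeepEP is only used through the flex dispatcher, and the flex
# dispatcher currently only supports the DeepEP backend.
def validate_deepep_flags(moe_token_dispatcher_type: str, moe_enable_deepep: bool) -> None:
    if moe_enable_deepep and moe_token_dispatcher_type != "flex":
        raise ValueError("--moe-enable-deepep requires --moe-token-dispatcher-type=flex")
    if moe_token_dispatcher_type == "flex" and not moe_enable_deepep:
        raise ValueError("the flex dispatcher currently requires --moe-enable-deepep")

validate_deepep_flags("flex", True)  # the recommended configuration
```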
### MoE Related Arguments
| Item | Description |
| --- | --- |
| --num-experts | Number of Experts in MoE (None means no MoE) |
| --expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| --moe-ffn-hidden-size | MoE Feed-Forward Network hidden size. Default is None. |
<details>
<summary> View all MoE related arguments. </summary>
| Item | Description |
| --- | --- |
| --num-experts | Number of Experts in MoE (None means no MoE) |
| --expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| --moe-ffn-hidden-size | MoE Feed-Forward Network hidden size. Default is None. |
| --expert-tensor-parallel-size | Degree of tensor model parallelism of the expert layer. Default is the same as --tensor-model-parallel-size. |
| --moe-layer-freq | Frequency between MoE layers and Dense layers. Accepts either: 1) An integer N for a 1:N ratio (one expert layer for every N-1 dense layers), 2) A string "N" for the same ratio, or 3) A string with a Python list expression for custom patterns like `([1]*3+[0]*1)*3` which gives [1,1,1,0,1,1,1,0,1,1,1,0] where 1=expert layer and 0=dense layer. Examples: `([0]+[1]*23)` for 1 dense layer followed by 23 expert layers, `([1]*3+[0]*2)*2` for three expert layers followed by two dense layers, repeated twice. Default is 1. The sketch after this table shows how these patterns expand. |
| --moe-grouped-gemm | When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine. |
| --moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| --moe-router-topk | Number of experts to route to for each token. The default is 2. |
| --moe-router-score-function | Score function for MoE routing. Can be "softmax" or "sigmoid". Default is "softmax". |
| --moe-router-pre-softmax | Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k. |
| --moe-router-num-groups | Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups is selected based on routing scores (the sum of the top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen. Two common use cases: 1) Device-limited routing: set equal to the expert parallel size (EP) to limit each token to experts on a subset of devices (see DeepSeek-V2: https://arxiv.org/pdf/2405.04434), 2) Node-limited routing: set equal to the number of nodes in the EP group to limit each token to experts on a subset of nodes (see DeepSeek-V3: https://arxiv.org/pdf/2412.19437). |
| --moe-router-group-topk | Number of selected groups for group-limited routing. |
| --moe-router-topk-scaling-factor | Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling. |
| --moe-router-enable-expert-bias | TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy. The routing decision is based on the sum of the routing scores and the expert bias. See https://arxiv.org/abs/2408.15664 for details. |
| --moe-router-bias-update-rate | The expert bias is updated based on the number of tokens assigned to each expert in a global batch: the bias is increased for experts with fewer assigned tokens and decreased for experts with more assigned tokens. Default is 1e-3, the same as used in DeepSeek-V3. |
| --moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| --moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| --moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| --moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather", "alltoall" and "alltoall_seq". Default is "allgather". We recommend using 'alltoall' if expert parallelism is applied. We have upgraded the "alltoall" dispatcher in place during MCore v0.9, while retaining the original implementation, renamed as "alltoall_seq".|
| --moe-enable-deepep | (Experimental) Enable DeepSeek's DeepEP for efficient token dispatching and combining in MoE models. Only works with the flex token dispatcher (set --moe-token-dispatcher-type=flex). |
| --moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| --moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
| --moe-permute-fusion | Fuse token rearrangement ops during token dispatching. |
| --moe-shared-expert-intermediate-size | Set shared expert total ffn hidden size. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. |
| --moe-shared-expert-overlap | (Experimental, may change) If this is set, the communications/computations in the shared experts and the dispatcher will overlap (the `alltoall` dispatcher is required). Otherwise, the shared expert runs after the routed experts. |
| --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on top of distributed checkpointing, so it supports parallel modes different from the dense model. |
</details>
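The `--moe-layer-freq` patterns above expand into a per-layer list where 1 marks an MoE layer and 0 a dense layer. The sketch below is an illustrative helper (not the actual MCore argument parsing) that mirrors the documented behavior.
```python
# Illustrative expansion of --moe-layer-freq values (1 = MoE layer, 0 = dense layer).
def expand_moe_layer_freq(pattern, num_layers: int) -> list:
    if isinstance(pattern, int):
        # An integer N gives a 1:N ratio: one MoE layer for every N-1 dense layers.
        return [1 if i % pattern == 0 else 0 for i in range(num_layers)]
    # A string holds a Python list expression, e.g. "([1]*3+[0]*1)*3".
    layer_pattern = eval(pattern)
    assert len(layer_pattern) == num_layers, "pattern length must match --num-layers"
    return layer_pattern

print(expand_moe_layer_freq("([1]*3+[0]*1)*3", 12))  # [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0]
print(expand_moe_layer_freq(1, 4))                   # [1, 1, 1, 1] -- every layer is MoE
```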
## MoE training example:
<details>
<summary>Click here. </summary>
......@@ -233,6 +266,7 @@ MOE_ARGS=(
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
--moe-permute-fusion
)
DATA_ARGS=(
......@@ -328,13 +362,13 @@ Here we provide some general rules to get better performance:
- In practice, EP8TP1 is better than EP4TP2 for 8x7B.
5. Enable Context Parallelism for long context training.
- The efficiency of CP largely depends on whether its communication can be overlapped with computation.
- Emperically, use CP when sequence length >= 8K.
- Empirically, use CP when sequence length >= 8K.
### MoE Parallel Folding
MoE Parallel Folding separates the MoE-related parallel groups from the Dense groups.
1. Traditional MoE parallel groups are entangled with the Dense groups by using a 5-dimensional parallel group generator with the default order `tp-cp-ep-dp-pp`. The EP group in MoE is a sub-group of DP in Attention.
2. With MoE Parallel Fodling, we use a parallel group generator with `tp-cp-dp-pp` for Attention, and another with `tp-ep-dp-pp` for MoE. The EPxTP group in MoE is a sub-group of DPxCPxTP in Attention.
2. With MoE Parallel Folding, we use a parallel group generator with `tp-cp-dp-pp` for Attention, and another with `tp-ep-dp-pp` for MoE. The EPxTP group in MoE is a sub-group of DPxCPxTP in Attention.
By setting `--expert-tensor-parallel-size`, we can set MoE-specific TP size.
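The sketch below illustrates the folding with assumed, hypothetical parallel sizes: Attention uses a `tp-cp-dp-pp` decomposition while MoE uses `etp-ep-edp-pp`, and both must multiply out to the same world size.
```python
# Hypothetical sizes illustrating MoE Parallel Folding.
tp, cp, dp, pp = 2, 2, 8, 2      # Attention: tp-cp-dp-pp
etp, ep = 1, 8                   # MoE: expert TP (via --expert-tensor-parallel-size) and EP
world_size = tp * cp * dp * pp

# The expert data-parallel size is derived so that etp * ep * edp * pp covers the same world size.
edp = world_size // (etp * ep * pp)
assert etp * ep * edp * pp == world_size

# The EPxTP group in MoE folds into (is a sub-group of) the DPxCPxTP group in Attention.
assert (dp * cp * tp) % (etp * ep) == 0
print(f"world={world_size}, attention DPxCPxTP={dp * cp * tp}, MoE EPxETP={ep * etp}, expert DP={edp}")
```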
......@@ -356,6 +390,7 @@ By setting `--expert-tensor-parallel-size`, we can set MoE-specific TP size.
- Dispatcher `allgather` is the default option. It achieves better performance and efficiency when only tensor parallelism is used or when the Top-k value is very large.
- Dispatcher `alltoall` is recommended if expert parallelism is applied.
- Dispatcher `alltoall_seq` is the original implementation of `alltoall` and is retained in case of compatibility issues.
- Dispatcher `flex` is a new dispatcher that decouples the communication group from model parallelism. Currently, only the DeepEP backend is supported, enabled by setting `--moe-enable-deepep`.
**Enable Communication Overlap**
- Enable `--overlap-param-gather` and `--overlap-grad-reduce` with distributed optimizer.
......@@ -372,6 +407,10 @@ Therefore, there are two recommended ways during the first 200 steps to avoid th
1. Increase the `expert-tensor-parallel-size` and decrease `expert-model-parallel-size` to replace EP with TP in the MoELayer; this can prevent load imbalance between EP ranks. Since the current ETP implementation has some memory overhead, you can further enable activation recomputation only for the MoE layer by adding `--moe-layer-recompute`.
2. Set the capacity factor to a relatively small number like 1.0 by adding `--moe-expert-capacity-factor 1.0`.
**Leverage DeepSeek's DeepEP for High-Performance Cross-Node Token Dispatching**
- The primary advantage of DeepEP is its cross-node token communication efficiency, which delivers substantial performance improvements when deploying expert parallelism across multiple nodes with large TopK values.
- To enable DeepEP in your training configuration, simply set `--moe-token-dispatcher-type=flex` and `--moe-enable-deepep` in your command line arguments.
### Reference Best Parallel Mapping
Here are the reference parallel mappings of MCore v0.8 for Mixtral 8x7B and 8x22B models:
......
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import copy
import itertools
from copy import deepcopy
from functools import partial, wraps
......@@ -33,7 +34,10 @@ from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_object_for_checkpoint
from megatron.core.transformer.utils import (
make_sharded_object_for_checkpoint,
sharded_state_dict_default,
)
try:
......@@ -369,6 +373,7 @@ class GroupedMLP(MegatronModule):
v_tensors = []
w_lens = []
v_lens = []
expert_global_idx = local_expert_indices_offset + local_expert_idx
for input_dim_idx in range(self.config.hidden_size):
for glu_idx in range(2):
local_idx = (
......@@ -399,9 +404,6 @@ class GroupedMLP(MegatronModule):
== local_flattened_range.stop - local_flattened_range.start
)
start_pos += len(local_tensor)
expert_global_idx = (
local_expert_indices_offset + local_expert_idx
)
if glu_idx == 0:
w_tensors.append(local_tensor)
w_lens.append(len(local_tensor))
......@@ -427,7 +429,11 @@ class GroupedMLP(MegatronModule):
),
non_flat_local_shape,
*sharded_offsets,
(prepend_axis_num, expert_global_idx, num_global_experts),
(
prepend_axis_num,
expert_global_idx, # pylint: disable=E0606
num_global_experts,
),
(prepend_axis_num + 1 + tp_axis, tp_rank, tp_size * 2),
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
......@@ -556,12 +562,36 @@ class GroupedMLP(MegatronModule):
tp_axis = 0
with_glu = False
wkey = f'{prefix}experts.linear_fc2.weight'
"""
When MCore Custom FSDP `optim_grads_params` is enabled, it is necessary to save the tensor local shard.
This local shard is accessible through the `fully_shard_param_local_shard` attribute of the tensor.
This attribute contains the local shard of the fully sharded parameter, which is essential for
correctly saving and loading the model state when using `optim_grads_params` with FSDP.
Example:
>>> # Assuming `tensor` is a fully sharded parameter
>>> local_shard = tensor.fully_shard_param_local_shard
>>> # Save the local shard as needed
"""
this_replica_id = list(copy.deepcopy(replica_id))
if hasattr(tensor, 'fully_shard_param_local_shard'):
if tensor.fully_shard_param_local_shard.numel() == 0:
continue
flattened_range = slice(*tensor.fully_shard_param_local_index)
tensor = tensor.fully_shard_param_local_shard
this_replica_id[-1] = 0
else:
flattened_range = None
sharded_state_dict[f'{prefix}{name}'] = ShardedTensorFactory(
wkey,
tensor,
partial(sh_ten_build_fn, tp_axis=tp_axis, with_glu=with_glu),
partial(sh_ten_merge_fn, tp_axis=tp_axis, with_glu=with_glu),
replica_id,
tuple(this_replica_id),
flattened_range=flattened_range,
)
replica_id = (
......@@ -719,7 +749,7 @@ class TEGroupedMLP(MegatronModule):
"""
sharded_state_dict = {}
for name, module in self._modules.items():
sub_sd = module.sharded_state_dict(f'{name}.', sharded_offsets, metadata)
sub_sd = sharded_state_dict_default(module, f'{name}.', sharded_offsets, metadata)
if name == 'linear_fc1' and self.config.gated_linear_unit:
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
......@@ -749,15 +779,20 @@ class SequentialMLP(MegatronModule):
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
if config.moe_ffn_hidden_size == config.ffn_hidden_size:
super().__init__(config=config)
else:
# Local SequentialMLP can still be used here by overriding the ffn_hidden_size
# with a deepcopied config.
sequential_mlp_config = deepcopy(config)
sequential_mlp_config.ffn_hidden_size = config.moe_ffn_hidden_size
super().__init__(config=sequential_mlp_config)
self.add_bias = config.add_bias_linear
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
assert (
self.config.moe_ffn_hidden_size == self.config.ffn_hidden_size
), "Please use GroupedMLP or TEGroupedMLP when moe_ffn_hidden_size is \
different from ffn_hidden_size"
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
......@@ -844,6 +879,12 @@ class SequentialMLP(MegatronModule):
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
is_custom_fsdp_shard_tensor = getattr(sh_ten, "is_data_parallel_fully_shard", False)
if is_custom_fsdp_shard_tensor:
sh_ten.replica_id = (*replica_id[:2], 0)
continue
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_expert_data_parallel_rank(),
......
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Portions of this code are from DeepSeek DeepEP project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
try:
from deep_ep import Buffer
HAVE_DEEP_EP = True
except ImportError:
HAVE_DEEP_EP = False
import torch
_buffer = None
def get_hidden_bytes(x: torch.Tensor) -> int:
"""Calculate the number of hidden bytes for a tensor.
Args:
x (torch.Tensor): Input tensor
Returns:
int: Number of hidden bytes
"""
return x.size(1) * max(x.element_size(), 2)
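# Example: for a bf16 activation of shape [num_tokens, 4096], element_size() is 2, so this
# returns 4096 * 2 = 8192 hidden bytes per token; max(..., 2) keeps at least 2 bytes per element.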
def get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int):
"""Get or create a buffer for all-to-all communication.
Args:
group (torch.distributed.ProcessGroup): Process group for communication
hidden_bytes (int): Number of hidden bytes needed
Returns:
Buffer: Communication buffer
"""
global _buffer
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (
Buffer.get_dispatch_config(group.size()),
Buffer.get_combine_config(group.size()),
):
# Take the maximum NVL/RDMA buffer size hints over the dispatch and combine configs
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
)
num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
)
# Allocate a buffer if none exists yet or the existing one is too small
# NOTES: the adaptive routing configuration of the network **must be off**
if (
_buffer is None
or _buffer.group != group
or _buffer.num_nvl_bytes < num_nvl_bytes
or _buffer.num_rdma_bytes < num_rdma_bytes
):
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer
class FusedDispatch(torch.autograd.Function):
"""Fused dispatch operation for MoE routing combining computation and communication."""
@staticmethod
def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None):
"""Forward pass of fused dispatch."""
# Calculate layout before actual dispatch
buffer = get_buffer(group, get_hidden_bytes(x))
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
previous_event,
) = buffer.get_dispatch_layout(
token_indices,
num_experts,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False,
)
# Do MoE dispatch
# NOTES: the CPU will wait for GPU's signal to arrive,
# so this is not compatible with CUDA graph
(
recv_x,
recv_token_indices,
recv_token_probs,
num_recv_tokens_per_expert_list,
handle,
event,
) = buffer.dispatch(
x,
topk_idx=token_indices,
topk_weights=token_probs.float(),
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False,
)
ctx.group = group
ctx.handle = handle
ctx.event = event
tokens_per_expert = torch.tensor(num_recv_tokens_per_expert_list)
return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle)
@staticmethod
def backward(
ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle
):
"""Backward pass of fused dispatch."""
buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
handle = ctx.handle
grad_x, grad_token_probs, event = buffer.combine(
grad_output.contiguous(),
handle,
topk_weights=grad_token_probs.float(),
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False,
)
return grad_x, None, grad_token_probs, None, None, None
class FusedCombine(torch.autograd.Function):
"""Fused combine operation for MoE output combining computation and communication."""
@staticmethod
def forward(ctx, x, group, handle, previous_event=None):
"""Forward pass of fused combine."""
buffer = get_buffer(group, get_hidden_bytes(x))
combined_x, _, event = buffer.combine(
x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False
)
ctx.handle = handle
ctx.group = group
return combined_x, event
@staticmethod
def backward(ctx, grad_output, previous_event=None):
"""Backward pass of fused combine."""
buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
grad_x, _, _, _, _, event = buffer.dispatch(
grad_output.contiguous(),
handle=ctx.handle,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
)
return grad_x, None, None, None
if HAVE_DEEP_EP:
def fused_dispatch(x, token_indices, token_probs, num_experts, group, previous_event=None):
"""Perform fused dispatch operation if deep_ep is available.
Args:
x: Input tensor [num_tokens, hidden_size]
token_indices: Token routing indices [num_tokens, topk]
token_probs: Token routing probabilities [num_tokens, topk]
num_experts: Number of experts
group: Process group
previous_event: Previous CUDA event
Returns:
Result of FusedDispatch
"""
return FusedDispatch.apply(
x.contiguous(), token_indices, token_probs, num_experts, group, previous_event
)
def fused_combine(x, group, handle, previous_event=None):
"""Perform fused combine operation if deep_ep is available.
Args:
x: Input tensor
group: Process group
handle: Communication handle
previous_event: Previous CUDA event
Returns:
Result of FusedCombine
"""
return FusedCombine.apply(x, group, handle, previous_event)
else:
fused_dispatch = None
fused_combine = None
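# Usage sketch (illustrative, not part of this module): assuming torch.distributed is initialized
# and deep_ep is installed, the helpers above are intended to be called roughly as follows.
# Shapes follow the docstrings: x is [num_tokens, hidden_size], token_indices and token_probs
# are [num_tokens, topk].
#
# def dispatch_and_combine(x, token_indices, token_probs, num_experts, group):
#     recv_x, recv_indices, recv_probs, tokens_per_expert, handle = fused_dispatch(
#         x, token_indices, token_probs, num_experts, group
#     )
#     expert_output = recv_x  # placeholder for the local expert computation
#     combined, _ = fused_combine(expert_output, group, handle)
#     return combined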
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# type: ignore
# This file will be deprecated soon. We won't fix the mypy type checks.
from typing import List, Optional, Tuple
import torch
......@@ -60,13 +63,13 @@ class MoEAlltoAllSEQTokenDispatcher(MoETokenDispatcher):
self.num_global_tokens_per_local_expert_cpu = None
input_chunk_idxs = torch.arange(self.num_experts)
# [num_local_experts, ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = (
input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist()
)
self.sort_input_by_local_experts = input_chunk_idxs.reshape(
-1, self.num_local_experts
).T.ravel()
# [ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = (
input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist()
)
self.restore_output_by_local_experts = input_chunk_idxs.reshape(
self.num_local_experts, -1
).T.ravel()
# Token drop and padding.
# We need to keep track of the token num if we drop tokens without padding them.
......
......@@ -7,17 +7,15 @@ from typing import Union
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
MoEFlexTokenDispatcher,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
......@@ -80,7 +78,7 @@ class MoELayer(BaseMoELayer):
"""
def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
self, config: TransformerConfig, submodules: MoESubmodules = None, layer_number: int = None
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
......@@ -89,20 +87,6 @@ class MoELayer(BaseMoELayer):
# Initialize router
self.router = TopKRouter(config=self.config)
# Initialize experts
if self.config.moe_grouped_gemm:
if isinstance(self.submodules.experts, MLPSubmodules):
self.experts = TEGroupedMLP(
self.num_local_experts, self.config, self.submodules.experts
)
else:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules.experts, MLPSubmodules)
self.experts = SequentialMLP(
self.num_local_experts, self.config, self.submodules.experts
)
# Initialize token dispatcher
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
......@@ -116,14 +100,21 @@ class MoELayer(BaseMoELayer):
self.token_dispatcher = MoEAlltoAllSEQTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "flex":
self.token_dispatcher = MoEFlexTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
else:
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)
# Initialize experts
self.experts = build_module(self.submodules.experts, self.num_local_experts, self.config)
# Initialize shared experts
if self.use_shared_expert:
self.shared_experts = SharedExpertMLP(self.config, self.submodules.shared_experts)
self.shared_experts = build_module(self.submodules.shared_experts, config=self.config)
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)
......@@ -149,7 +140,7 @@ class MoELayer(BaseMoELayer):
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output += self.shared_experts(hidden_states)
output = output + self.shared_experts(hidden_states)
return output, mlp_bias
if self.moe_layer_recompute:
......
......@@ -6,6 +6,18 @@ from typing import Optional
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
try:
from megatron.core.extensions.transformer_engine import (
fused_permute,
fused_sort_chunks_by_index,
fused_unpermute,
)
HAVE_TE = True
except ImportError:
HAVE_TE = False
def switch_load_balancing_loss_func(
......@@ -56,6 +68,59 @@ def switch_load_balancing_loss_func(
return aux_loss
def sequence_load_balancing_loss_func(
probs: torch.Tensor,
routing_map: torch.Tensor,
batch_size: int,
seq_length: int,
topk: int,
moe_aux_loss_coeff: float,
sequence_partition_group=None,
):
"""
Calculate the sequence-level auxiliary loss by computing the loss for each individual sample.
Refer to the DeepSeek-V2 huggingface repo
(https://huggingface.co/deepseek-ai/DeepSeek-V2) for details.
Args:
probs (torch.Tensor): Softmax probabilities output by the router for each token.
Shape in [num_tokens, num_experts].
routing_map (torch.Tensor): Mapping of tokens to experts assignment.
Shape in [num_tokens, num_experts].
batch_size (int): Batch size to process.
seq_length (int): Sequence length to process.
topk (int): Number of experts to route to for each token.
moe_aux_loss_coeff (float): Scaling coefficient for the auxiliary loss.
sequence_partition_group (optional): The parallel group over which the sequence is
partitioned. If None, no partitioning is applied.
Defaults to None.
Returns:
torch.Tensor: The sequence auxiliary loss for load balancing.
"""
num_sub_sequence = 1
num_experts = probs.shape[1]
probs_for_aux_loss = probs.view(seq_length, batch_size, -1)
routing_map = routing_map.view(seq_length, batch_size, -1)
# If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism
# or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full
# sequence.
if sequence_partition_group is not None:
num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group)
seq_length *= num_sub_sequence
probs_for_aux_loss = gather_from_sequence_parallel_region(
probs_for_aux_loss, group=sequence_partition_group
)
cost_coeff = routing_map.sum(dim=0, dtype=torch.float).div_(seq_length * topk / num_experts)
seq_aux_loss = (cost_coeff * probs_for_aux_loss.mean(dim=0)).sum(dim=1).mean()
seq_aux_loss *= moe_aux_loss_coeff
return seq_aux_loss
def z_loss_func(logits, z_loss_coeff):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
......@@ -108,7 +173,7 @@ def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_
class MoEAuxLossAutoScaler(torch.autograd.Function):
"""An AutoScaler that compute and scales the grad for auxiliary loss."""
"""An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss."""
main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
......@@ -153,29 +218,60 @@ class MoEAuxLossAutoScaler(torch.autograd.Function):
MoEAuxLossAutoScaler.main_loss_backward_scale = scale
def permute(tokens, routing_map, num_out_tokens: int = None):
def permute(
tokens,
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""Permute the tokens and probs based on the mask.
Tokens with the same designated expert will be grouped together.
The shape of mask is [tokens, num_experts], it indicates which experts were selected
by each token.
When drop_and_pad=True, the number of non-zeros in each column of routing_map equals the
expert capacity. This function exploits this feature to use ops that support cuda graph.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
the number of input tokens.
fused (bool, optional): Whether to use the fused permute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
If set to true, routing_map has a fixed number of non-zeros
in each column.
"""
if fused:
if not HAVE_TE or fused_permute is None:
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
return fused_permute(tokens, routing_map, num_out_tokens)
num_tokens, hidden = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
# use argsort to put indices of all non-zeros in the beginning of list
# and keep the first `capacity` number of indices
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
# flatten from [num_experts, capacity] to 1D
sorted_indices = sorted_indices.view(-1)
else:
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
routing_map = routing_map.bool().T.contiguous()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
token_indices = (
torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
......@@ -189,11 +285,19 @@ def unpermute(
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Restore the original order of tokens after permutation. If probs are provided, it
will also apply them to the tokens before restoring the order.
When drop_and_pad=True, the tensors will have the following properties:
- In routing_map, the number of non-zeros in each column equals the expert capacity
- The size of sorted_indices equals num_experts * capacity; each split of `capacity`
contains the indices of tokens routed to an expert.
This function exploits these features to use ops that support cuda graph.
Args:
permuted_tokens (torch.Tensor): The permuted token tensor.
sorted_indices (torch.Tensor): The indices used to sort the tokens.
......@@ -201,15 +305,40 @@ def unpermute(
probs (torch.Tensor, optional): The unpermuted probs tensor,
routing_map (torch.Tensor, optional): Token to expert mapping, shape
[num_tokens, num_experts].
fused (bool, optional): Whether to use the fused unpermute function.
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
and pads the number of tokens to the expert capacity.
Returns:
torch.Tensor: The tokens restored to their original order.
"""
if fused:
if not HAVE_TE or fused_unpermute is None:
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape)
_, hidden = restore_shape
if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
probs_T_1D = probs.T.contiguous().view(-1)
# get 1D indices of the probs selected by routing_map
indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
# get probs from indices
permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
# Create an output tensor filled with zeros
......@@ -221,13 +350,77 @@ def unpermute(
return output_tokens
def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):
def sort_chunks_by_idxs(
input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False
):
"""Split and sort the input tensor based on the split_sizes and sorted indices."""
if fused:
if not HAVE_TE or fused_sort_chunks_by_index is None:
raise ValueError(
"fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0."
)
return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs)
input = torch.split(input, split_sizes.tolist(), dim=0)
output = torch.cat([input[i] for i in sorted_idxs], dim=0)
output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0)
return output
def group_limited_topk(
scores: torch.Tensor,
topk: int,
num_tokens: int,
num_experts: int,
num_groups: int,
group_topk: int,
):
"""Perform top-k routing on a subset of expert groups.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
(specifically, the sum of top-2 expert scores within each group)
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
Args:
scores (torch.Tensor): Softmax scores generated by the router.
topk (int): The number of experts to select for each token.
num_tokens (int): The number of tokens.
num_experts (int): The number of experts.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of groups selected for each token.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor.
"""
# Organize the experts into groups
group_scores = scores.view(num_tokens, num_groups, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
# Mask the experts based on the selected groups
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_tokens, num_groups, num_experts // num_groups)
.reshape(num_tokens, -1)
)
masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf'))
probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
return probs, top_indices
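# Usage sketch (illustrative shapes): with 8 experts split into 4 groups and group_topk=2,
# each token is first restricted to its 2 highest-scoring groups, and the final top-k experts
# are then selected only from those groups:
#   scores = torch.softmax(torch.randn(4, 8), dim=-1)  # [num_tokens=4, num_experts=8]
#   probs, idx = group_limited_topk(
#       scores, topk=2, num_tokens=4, num_experts=8, num_groups=4, group_topk=2
#   )
#   # probs and idx have shape [4, 2]; idx only contains experts from the selected groups.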
def topk_softmax_with_capacity(
logits: torch.Tensor,
topk: int,
......@@ -235,18 +428,32 @@ def topk_softmax_with_capacity(
pad_to_capacity: bool = False,
drop_policy: str = "probs",
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
deterministic_mode: bool = False,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (int): The capacity factor of each expert. Will drop tokens if the number
capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number
of tokens exceeds the capacity.
pad_to_capacity (bool): Whether to need padding in token drop mode.
pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded
tokens will be 0.
drop_policy (str): The policy to drop tokens. Can be either "prob" or "position".
If "prob", the tokens with the lowest probabilities will be dropped.
If "position", tokens at the end of each batch will be dropped.
use_pre_softmax (bool): Whether to apply softmax before top-k selection.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of selected groups for each token.
scaling_factor (float): Scaling factor of routing score in top-k selection.
deterministic_mode (bool): Deprecated.
score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
expert_bias (torch.Tensor): The bias added to logits for expert routing.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
......@@ -255,23 +462,45 @@ def topk_softmax_with_capacity(
indicating which experts were selected for each token. True values represent
the selected experts.
- tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
the number of local tokens assigned to each expert.
the number of local tokens assigned to each expert before dropping and padding.
"""
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens = logits.shape[0]
num_experts = logits.shape[1]
if use_pre_softmax:
# Pre softmax
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = torch.topk(scores, k=topk, dim=1)
num_tokens, num_experts = logits.shape
def compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
topk=topk,
num_tokens=num_tokens,
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
)
else:
return torch.topk(scores, k=topk, dim=1)
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
# Post softmax
if topk == 1:
# Requires applying softmax before selecting the top-k when k is 1,
# since softmax on a [num_tokens, 1] would yield a zero gradient.
raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")
scores, top_indices = torch.topk(logits, k=topk, dim=1)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
raise ValueError(f"Invalid score_function: {score_function}")
if scaling_factor:
probs = probs * scaling_factor
# TODO Try using element-wise operations instead of scatter?
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
......@@ -405,3 +634,23 @@ def track_moe_metrics(
)
clear_aux_losses_tracker()
def get_updated_expert_bias(tokens_per_expert, expert_bias, expert_bias_update_rate):
"""Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1#
Args:
tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert.
expert_bias (torch.Tensor): The bias for each expert.
expert_bias_update_rate (float): The update rate for the expert bias.
"""
with torch.no_grad():
# All Reduce Across TPxCPxDP group
torch.distributed.all_reduce(
tokens_per_expert,
group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True),
)
average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1]
offset = average_tokens - tokens_per_expert
updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate
return updated_expert_bias
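# Worked example (single rank, ignoring the all-reduce): with
#   tokens_per_expert = [10., 2., 4.]  ->  average_tokens = 16 / 3 ≈ 5.33
# the offsets are [-4.67, +3.33, +1.33], so with the default update rate 1e-3 the bias of the
# over-loaded expert decreases by 1e-3 while the biases of the under-loaded experts increase by 1e-3.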
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable
import torch
......@@ -10,6 +12,7 @@ from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sequence_load_balancing_loss_func,
sinkhorn,
switch_load_balancing_loss_func,
topk_softmax_with_capacity,
......@@ -99,8 +102,23 @@ class TopKRouter(Router):
super().__init__(config=config)
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.score_function = self.config.moe_router_score_function
self.input_jitter = None
self.enable_expert_bias = self.config.moe_router_enable_expert_bias
if self.enable_expert_bias:
self.register_buffer(
'local_tokens_per_expert',
torch.zeros(self.config.num_moe_experts, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
'expert_bias', torch.zeros(self.config.num_moe_experts, dtype=torch.float32)
)
else:
self.local_tokens_per_expert = None
self.expert_bias = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
......@@ -142,7 +160,7 @@ class TopKRouter(Router):
Returns:
probs (torch.Tensor): The probabilities of token to experts assignment.
indices (torch.Tensor): The mask of token to experts assignment.
routing_map (torch.Tensor): The mask of token to experts assignment.
"""
probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
logits,
......@@ -151,47 +169,78 @@ class TopKRouter(Router):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
)
if self.training:
# Apply load balancing loss
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
aux_loss_func = partial(
switch_load_balancing_loss_func,
probs=scores,
tokens_per_expert=tokens_per_expert,
topk=self.topk,
)
probs = self.apply_load_balancing_loss(
activation=probs, load_balancing_loss_func=aux_loss_func
)
return probs, routing_map
def apply_load_balancing_loss(
self,
probs: torch.Tensor,
num_local_tokens_per_expert: torch.Tensor,
activation: torch.Tensor,
):
"""Applies auxiliary loss to the MoE layer.
def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length: int):
"""Apply loss-based load balancing to the logits tensor."""
Args:
probs (torch.Tensor): The probs output by the router for each token.
[num_tokens, num_experts]
num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert.
[num_experts]
activation (torch.Tensor): The activation tensor to attach the gradient function to.
probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
)
Returns:
torch.Tensor: The activation tensor with the attached gradient function.
"""
if self.training:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
aux_loss_func = partial(
sequence_load_balancing_loss_func,
probs=scores,
routing_map=routing_map,
batch_size=bsz,
seq_length=seq_length,
topk=self.topk,
)
probs = self.apply_load_balancing_loss(
activation=probs, load_balancing_loss_func=aux_loss_func
)
return probs, routing_map
def apply_load_balancing_loss(
self, activation: torch.Tensor, load_balancing_loss_func: Callable
):
"""Calculate auxiliary loss, attach gradient function to activation and add to logging."""
moe_aux_loss_coeff = self.config.moe_aux_loss_coeff
if moe_aux_loss_coeff == 0:
return activation
sequence_partition_group = None
if self.config.moe_token_dispatcher_type == "alltoall_seq":
sequence_partition_group = parallel_state.get_context_parallel_group()
moe_aux_loss_coeff /= parallel_state.get_tensor_model_parallel_world_size()
else:
elif parallel_state.get_tensor_and_context_parallel_world_size() > 1:
sequence_partition_group = parallel_state.get_tensor_and_context_parallel_group()
aux_loss = switch_load_balancing_loss_func(
probs,
num_local_tokens_per_expert,
self.topk,
moe_aux_loss_coeff,
sequence_partition_group=sequence_partition_group,
aux_loss = load_balancing_loss_func(
moe_aux_loss_coeff=moe_aux_loss_coeff, sequence_partition_group=sequence_partition_group
)
save_to_aux_losses_tracker(
"load_balancing_loss",
......@@ -257,6 +306,7 @@ class TopKRouter(Router):
routing_map (torch.Tensor): The mapping of token to experts assignment,
with shape [num_tokens, num_experts].
"""
seq_length, bsz = logits.shape[:2]
logits = logits.view(-1, self.config.num_moe_experts)
# Apply Z-Loss
......@@ -270,6 +320,8 @@ class TopKRouter(Router):
scores, routing_map = self.sinkhorn_load_balancing(logits)
elif self.routing_type == "aux_loss":
scores, routing_map = self.aux_loss_load_balancing(logits)
elif self.routing_type == "seq_aux_loss":
scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
elif self.routing_type == "none":
# A naive top-k routing without load balancing
scores, routing_map, _ = topk_softmax_with_capacity(
......@@ -279,10 +331,19 @@ class TopKRouter(Router):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
# Prevent extra accumulation of local token counts during evaluation or activation recomputation
if self.enable_expert_bias and torch.is_grad_enabled():
with torch.no_grad():
self.local_tokens_per_expert += routing_map.sum(dim=0)
return scores, routing_map
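The `local_tokens_per_expert` counter accumulated above feeds the aux-loss-free balancing strategy, where an expert bias is nudged toward under-loaded experts between steps. Below is a minimal sketch assuming a DeepSeek-V3-style sign rule and a hypothetical `update_rate`; the actual MCore update lives elsewhere:

```python
# Hypothetical sketch of an aux-loss-free bias update: over-loaded experts get their
# routing bias decreased, under-loaded experts get it increased. Names and the
# update-rate handling are illustrative, not the exact MCore implementation.
import torch

@torch.no_grad()
def update_expert_bias_sketch(expert_bias, local_tokens_per_expert, update_rate=1e-3):
    mean_load = local_tokens_per_expert.float().mean()
    # sign(+) for under-loaded experts, sign(-) for over-loaded experts.
    expert_bias += update_rate * torch.sign(mean_load - local_tokens_per_expert.float())
    local_tokens_per_expert.zero_()  # reset the per-step counter
    return expert_bias
```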
......@@ -293,12 +354,10 @@ class TopKRouter(Router):
Args:
input (torch.Tensor): Input tensor.
"""
self.hidden = input.shape[-1]
# Apply input jitter
input = self.apply_input_jitter(input)
logits = self.gating(input)
logits = logits.view(-1, self.config.num_moe_experts)
scores, routing_map = self.routing(logits)
......
......@@ -17,8 +17,7 @@ from megatron.core.tensor_parallel.mappings import (
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_torch_min_version, make_sharded_tensor_for_checkpoint
......@@ -32,15 +31,15 @@ class SharedExpertMLP(MLP):
# The shared experts are scheduled into this stream to be overlapped with the dispatcher.
stream = None
def __init__(self, config: TransformerConfig, spec: ModuleSpec):
def __init__(self, config: TransformerConfig, submodules: MLPSubmodules, gate: bool):
config = deepcopy(config)
assert config.add_bias_linear == False, (
    "bias is not supported in the shared experts, "
    "please set '--disable-bias-linear' instead."
)
config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
super().__init__(config=config, submodules=spec.submodules)
super().__init__(config=config, submodules=submodules)
self.use_shared_expert_gate = spec.params.get("gate", False)
self.use_shared_expert_gate = gate
if self.use_shared_expert_gate:
# TODO: Add support for GPU initialization, which requires updating the golden values.
self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
......
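A hedged construction example after this signature change, where the gate flag is passed explicitly rather than read from `spec.params` (the surrounding builder code and the `mlp_submodules` name are assumptions):

```python
# Assumed usage after the signature change: the gate flag is an explicit argument.
shared_experts = SharedExpertMLP(config, submodules=mlp_submodules, gate=True)
```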
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
import torch
......@@ -16,6 +16,7 @@ from megatron.core.tensor_parallel import (
gather_from_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.moe.fused_a2a import fused_combine, fused_dispatch
from megatron.core.transformer.moe.moe_utils import (
get_capacity,
permute,
......@@ -102,6 +103,7 @@ class MoETokenDispatcher:
def set_shared_experts(self, shared_experts):
"""Set shared expert to the dispatcher."""
assert self.config.moe_shared_expert_overlap
self.shared_experts = shared_experts
......@@ -125,9 +127,6 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
self.router_topk = config.moe_router_topk
self.add_bias = config.add_bias_linear
# self.local_probs: probs of global token assignment to local experts.
self.local_probs = None
# self.global_local_map: 2D tensor. A mask of mapping between global and local tokens where
# each element is True if it's between the local_expert_indices. Only useful when cross
# device token permutation is enabled and **AllGather** is performed.
......@@ -183,6 +182,7 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
self.local_map = routing_map[
:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
].contiguous()
# probs of global token assignment to local experts.
self.local_probs = probs[
:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
].contiguous()
......@@ -190,7 +190,10 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
tokens_per_expert = self.local_map.sum(dim=0).long().cpu()
(permuted_local_hidden_states, self.reversed_local_input_permutation_mapping) = permute(
hidden_states, self.local_map
hidden_states,
self.local_map,
num_out_tokens=tokens_per_expert.sum(),
fused=self.config.moe_permute_fusion,
)
return permuted_local_hidden_states, tokens_per_expert
......@@ -220,6 +223,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
hidden_states,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
routing_map=self.local_map,
fused=self.config.moe_permute_fusion,
)
unpermuted_local_bias = None
......@@ -230,6 +235,8 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
bias,
self.reversed_local_input_permutation_mapping,
restore_shape=self.hidden_shape_before_permute,
routing_map=self.local_map,
fused=self.config.moe_permute_fusion,
)
output_total = unpermuted_local_hidden
......@@ -279,8 +286,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
config (TransformerConfig): Configuration for the transformer model.
"""
super().__init__(config=config)
self.hidden_shape = None
self.num_local_experts = num_local_experts
assert config.num_moe_experts is not None
self.num_experts = config.num_moe_experts
assert self.num_local_experts > 0, "Expected at least one expert"
self.local_expert_indices = local_expert_indices
......@@ -291,7 +298,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
assert (
self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1
), "local_expert_indices must be continous"
self.probs = None
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
......@@ -302,26 +308,25 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks.
self.output_splits_tp = None
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert_cpu = None
input_chunk_idxs = torch.arange(self.num_experts * self.tp_size)
# [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = (
input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel().tolist()
self.permute_idx_device = torch.device("cuda") if self.config.moe_permute_fusion else None
input_chunk_idxs = torch.arange(
self.num_experts * self.tp_size, device=self.permute_idx_device
)
# [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts.
self.sort_input_by_local_experts = input_chunk_idxs.reshape(
-1, self.num_local_experts
).T.ravel()
# [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts.
self.restore_output_by_local_experts = (
input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel().tolist()
)
self.restore_output_by_local_experts = input_chunk_idxs.reshape(
self.num_local_experts, -1
).T.ravel()
# Token drop and padding.
# We need to keep track of the token num if we drop tokens without padding them.
self.num_out_tokens = None
# Drop and pad the input to capacity.
self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity
if self.drop_and_pad:
assert self.config.moe_expert_capacity_factor is not None
self.moe_expert_capacity_factor = self.config.moe_expert_capacity_factor
self.capacity = None
# A cuda stream synchronization is needed in self.token_permutation() in some cases,
......@@ -357,7 +362,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.capacity = get_capacity(
num_tokens=num_tokens,
num_experts=self.num_experts,
capacity_factor=self.config.moe_expert_capacity_factor,
capacity_factor=self.moe_expert_capacity_factor,
)
self.num_out_tokens = self.capacity * self.num_experts
# [num_local_experts], number of tokens processed by each expert.
......@@ -366,9 +371,13 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.capacity * self.tp_size * self.ep_size,
dtype=torch.long,
)
# [tp_size * ep_size, num_local_experts].
self.num_global_tokens_per_local_expert_cpu = torch.full(
(self.num_experts * self.tp_size,), self.capacity, dtype=torch.long
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = torch.full(
(self.num_experts * self.tp_size,),
self.capacity,
dtype=torch.long,
device=self.permute_idx_device,
)
return num_tokens_per_local_expert
elif self.config.moe_expert_capacity_factor is not None:
......@@ -395,6 +404,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self.input_splits = (
num_local_tokens_per_expert.reshape(self.ep_size, self.num_local_experts)
.sum(axis=1)
......@@ -447,9 +458,15 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
)
if self.num_local_experts > 1:
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(
    -1, self.num_local_experts
)
if not self.config.moe_permute_fusion:
self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.to(
torch.device("cpu"), non_blocking=False
)
return num_tokens_per_local_expert
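The `get_capacity` helper imported from `moe_utils` determines the padded per-expert capacity used above. A hedged sketch of the usual formula, which may differ from the real helper in rounding and minimum-capacity details:

```python
# Illustrative capacity formula: tokens are split evenly across experts and scaled by
# the capacity factor, rounded up; an optional minimum capacity can be enforced.
import math

def get_capacity_sketch(num_tokens, num_experts, capacity_factor, min_capacity=None):
    capacity = math.ceil(num_tokens / num_experts * capacity_factor)
    if min_capacity is not None:
        capacity = max(capacity, min_capacity)
    return capacity
```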
......@@ -493,7 +510,11 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if self.cuda_sync_point == "before_permutation_1":
torch.cuda.current_stream().synchronize()
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states, routing_map, num_out_tokens=self.num_out_tokens
hidden_states,
routing_map,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
# Perform expert parallel AlltoAll communication
......@@ -506,21 +527,35 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
if self.tp_size > 1:
if self.output_splits_tp is None:
output_split_sizes = None
else:
output_split_sizes = self.output_splits_tp.tolist()
global_input_tokens = gather_from_sequence_parallel_region(
global_input_tokens,
group=self.tp_group,
output_split_sizes=(
self.output_splits_tp.tolist() if self.output_splits_tp is not None else None
),
global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
)
# Permutation 2: Sort tokens by local expert.
if self.num_local_experts > 1:
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert_cpu.ravel(),
self.sort_input_by_local_experts,
)
if self.drop_and_pad:
global_input_tokens = (
global_input_tokens.view(
self.tp_size * self.ep_size,
self.num_local_experts,
self.capacity,
*global_input_tokens.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
else:
global_input_tokens = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert.ravel(),
self.sort_input_by_local_experts,
fused=self.config.moe_permute_fusion,
)
if self.cuda_sync_point == "before_finish":
torch.cuda.current_stream().synchronize()
......@@ -551,19 +586,33 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# Unpermutation 2: Unsort tokens by local expert.
if self.num_local_experts > 1:
hidden_states = sort_chunks_by_idxs(
hidden_states,
self.num_global_tokens_per_local_expert_cpu.T.ravel(),
self.restore_output_by_local_experts,
)
if self.drop_and_pad:
hidden_states = (
hidden_states.view(
self.num_local_experts,
self.tp_size * self.ep_size,
self.capacity,
*hidden_states.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
else:
hidden_states = sort_chunks_by_idxs(
hidden_states,
self.num_global_tokens_per_local_expert.T.ravel(),
self.restore_output_by_local_experts,
fused=self.config.moe_permute_fusion,
)
if self.tp_size > 1:
if self.output_splits_tp is None:
input_split_sizes = None
else:
input_split_sizes = self.output_splits_tp.tolist()
hidden_states = reduce_scatter_to_sequence_parallel_region(
hidden_states,
group=self.tp_group,
input_split_sizes=(
self.output_splits_tp.tolist() if self.output_splits_tp is not None else None
),
hidden_states, group=self.tp_group, input_split_sizes=input_split_sizes
)
# Perform expert parallel AlltoAll communication
......@@ -582,6 +631,8 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
restore_shape=self.hidden_shape_before_permute,
probs=self.probs,
routing_map=self.routing_map,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
# Reshape the output tensor
......@@ -592,3 +643,280 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
shared_expert_output = self.shared_experts.get_output()
output += shared_expert_output
return output, None
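For the drop-and-pad branch of Permutation 2 above, the view/transpose/flatten regroups fixed-capacity chunks from (rank, local_expert) order into (local_expert, rank) order, so each local expert sees a contiguous block of tokens. A self-contained sketch with assumed toy sizes:

```python
# Minimal sketch (toy sizes) of the drop_and_pad regrouping performed in Permutation 2.
import torch

tp_ep, n_local, capacity, hidden = 4, 2, 3, 8          # illustrative sizes
tokens = torch.arange(tp_ep * n_local * capacity * hidden, dtype=torch.float32)
tokens = tokens.view(tp_ep * n_local * capacity, hidden)

regrouped = (
    tokens.view(tp_ep, n_local, capacity, hidden)       # chunks ordered by (rank, expert)
    .transpose(0, 1)                                     # reorder to (expert, rank)
    .contiguous()
    .flatten(start_dim=0, end_dim=2)
)
assert regrouped.shape == (n_local * tp_ep * capacity, hidden)
```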
class _DispatchManager(ABC):
"""
A manager class to handle dispatch and combine processes for MoE models.
The dispatch manager handles token dispatching according to a routing_map of shape
[num_local_tokens, world_size, num_instances]. The routing_map is a 3D tensor where each
element indicates whether a token should be sent to a specific rank.
num_instances is the maximum number of token instances dispatched to a target rank; it
can be the number of local experts or the size of the sub-group.
"""
@abstractmethod
def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor):
"""Set up metadata of routing_map and probs."""
pass
@abstractmethod
def dispatch(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Dispatch the hidden_states according to the routing_map."""
pass
@abstractmethod
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Combine the hidden_states after expert processing."""
pass
@abstractmethod
def get_dispached_metadata(self) -> torch.Tensor:
"""Get the metadata of the dispatched hidden_states."""
pass
@abstractmethod
def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Get the permuted hidden states by instances."""
pass
@abstractmethod
def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Get the restored hidden states by instances."""
pass
class _DeepepManager(_DispatchManager):
"""
A manager class to handle fused all-to-all communication processes for MoE models using
DeepEP backend. See https://github.com/deepseek-ai/deepep for more details.
The workflow of the DeepEP dispatcher is:
(1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata
(2) dispatch():
- Use a fused kernel to permute tokens and perform all-to-all communication in a single step
(3) get_permuted_hidden_states_by_experts():
- Convert routing map and probabilities to multihot format
- Permute tokens using fused kernel
(4) get_restored_hidden_states_by_experts():
- Reverse permutation using fused kernel
(5) combine():
- Reverse process using fused kernel to unpermute and perform all-to-all in a single step
This implementation uses fused communication kernels (fused_dispatch/fused_combine) that
combine permutation and communication operations for improved efficiency compared to
separate permute+alltoall steps.
"""
def __init__(
self,
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool = False,
capacity_factor: float = None,
num_experts: int = None,
num_local_experts: int = None,
):
self.group = group
self.router_topk = router_topk
self.capacity_factor = capacity_factor
self.permute_fusion = permute_fusion
self.num_experts = num_experts
self.num_local_experts = num_local_experts
# Metadata
self.token_indices = None
self.token_probs = None
# Handle used for combine operation
self.handle = None
if fused_dispatch is None:
raise ImportError(
"DeepEP is not installed. Please install DeepEP package from "
"https://github.com/deepseek-ai/deepep."
)
def setup_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor):
num_tokens = routing_map.shape[0]
routing_map = routing_map.reshape(num_tokens, self.num_experts)
probs = probs.reshape(num_tokens, self.num_experts)
# Convert the format of routing map from multihot to indices.
self.token_probs, self.token_indices = torch.topk(probs, self.router_topk, dim=-1)
# Mask the indices of dropped tokens with -1
if self.capacity_factor is not None:
mask = self.token_probs == 0
self.token_indices = self.token_indices.masked_fill(mask, -1)
def dispatch(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
fused_dispatch(
hidden_states, self.token_indices, self.token_probs, self.num_experts, self.group
)
)
self.handle = handle
self.tokens_per_expert = num_tokens_per_expert
self.dispatched_indices = dispatched_indices
self.dispatched_probs = dispatched_probs
return hidden_states
def _indices_to_multihot(self, indices, probs):
"""
Converts a tensor of indices to a multihot vector.
Args:
indices (torch.Tensor): [num_tokens, topk] token indices, where -1 means masked out.
probs (torch.Tensor): [num_tokens, topk] token probabilities.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- routing_map: Multihot vector.
- probs: Multihot probabilities.
"""
batch_size = indices.shape[0]
multihot_routing_map = torch.zeros(
(batch_size, self.num_local_experts), dtype=torch.long, device=indices.device
)
multihot_probs = torch.zeros(
(batch_size, self.num_local_experts), dtype=torch.float, device=indices.device
)
mask = indices != -1
valid_indices = indices[mask]
row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave(
mask.sum(dim=1)
)
multihot_routing_map[row_indices, valid_indices] = 1
multihot_probs[row_indices, valid_indices] = probs[mask]
return multihot_routing_map.bool(), multihot_probs
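As a standalone illustration of `_indices_to_multihot` above, with four local experts and `-1` marking a dropped token slot (toy values, not taken from the source):

```python
# Indices -> multihot conversion with num_local_experts = 4; -1 means a dropped slot.
import torch

indices = torch.tensor([[0, 2], [3, -1]])
probs = torch.tensor([[0.7, 0.3], [1.0, 0.0]])
num_local_experts = 4

routing_map = torch.zeros(2, num_local_experts, dtype=torch.bool)
multihot_probs = torch.zeros(2, num_local_experts)
mask = indices != -1
rows = torch.arange(2).repeat_interleave(mask.sum(dim=1))
routing_map[rows, indices[mask]] = True
multihot_probs[rows, indices[mask]] = probs[mask]
# routing_map -> [[True, False, True, False], [False, False, False, True]]
```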
def get_dispached_metadata(self) -> torch.Tensor:
return self.dispatched_indices, self.dispatched_probs
def get_number_of_tokens_per_expert(self) -> torch.Tensor:
"""
Get the number of tokens per expert.
"""
return self.tokens_per_expert
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, event = fused_combine(hidden_states, self.group, self.handle)
# Release the handle after combine operation
self.handle = None
return hidden_states
def get_permuted_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.dispatched_routing_map, self.dispatched_probs = self._indices_to_multihot(
self.dispatched_indices, self.dispatched_probs
)
self.hidden_shape_before_permute = hidden_states.shape
hidden_states, self.reversed_mapping_for_combine = permute(
hidden_states,
self.dispatched_routing_map,
num_out_tokens=sum(self.tokens_per_expert),
fused=self.permute_fusion,
)
return hidden_states
def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
assert self.dispatched_probs.dtype == torch.float32, "DeepEP only supports float32 probs"
hidden_states = unpermute(
hidden_states,
self.reversed_mapping_for_combine,
restore_shape=self.hidden_shape_before_permute,
routing_map=self.dispatched_routing_map,
probs=self.dispatched_probs,
fused=self.permute_fusion,
)
return hidden_states.to(input_dtype)
class MoEFlexTokenDispatcher(MoETokenDispatcher):
"""
Flexible token dispatcher for MoE models with Efficient-A2A communication kernels.
"""
def __init__(
self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig
):
super().__init__(config)
self.num_local_experts = num_local_experts
self.local_expert_indices = local_expert_indices
assert self.tp_size * self.ep_size > 1, "Flex token dispatcher requires TPxEP > 1"
assert (
self.config.moe_enable_deepep
), "DeepEP is not enabled. Please set --moe-enable-deepep to use DeepEP backend."
assert (
self.config.moe_pad_expert_input_to_capacity is False
), "Flex token dispatcher does not support --moe-pad-expert-input-to-capacity"
self._comm_manager = _DeepepManager(
group=self.tp_ep_group,
router_topk=self.tp_size * self.config.moe_router_topk,
permute_fusion=self.config.moe_permute_fusion,
capacity_factor=self.config.moe_expert_capacity_factor,
num_experts=self.tp_size * self.config.num_moe_experts,
num_local_experts=self.num_local_experts,
)
def set_shared_experts(self, shared_experts):
raise NotImplementedError("Shared experts overlap not supported in flex token dispatcher")
def _initialize_metadata(self, routing_map: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
"""
Initialize the routing map and probs to a unified format covering the TPxEP group.
This design decouples the communication group from underlying model parallelism groups,
such that the communication strategy of tokens can be agnostic of TP size and EP size.
This function expands the routing_map from shape [num_local_tokens, num_experts] to
[num_local_tokens, world_size, num_local_experts]. Each element in the routing_map
indicates whether a token should be sent to a specific rank. Specifically, the
routing_map is replicated across the TP group, since each TP rank in a TP group should
receive the same tokens.
"""
num_local_tokens = routing_map.shape[0]
world_size = self.tp_size * self.ep_size
# Organize routing map and probs to [num_local_tokens, world_size, num_local_experts]
routing_map = (
routing_map.reshape(num_local_tokens, self.ep_size, 1, self.num_local_experts)
.expand(-1, -1, self.tp_size, -1)
.reshape(num_local_tokens, world_size, self.num_local_experts)
).contiguous()
probs = (
probs.reshape(num_local_tokens, self.ep_size, 1, self.num_local_experts)
.expand(-1, -1, self.tp_size, -1)
.reshape(num_local_tokens, world_size, self.num_local_experts)
).contiguous()
return routing_map, probs
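A tiny worked example of the expansion in `_initialize_metadata`, assuming `ep_size=2`, `tp_size=2`, and one local expert, so `world_size=4`:

```python
# One token routed to expert 0 is replicated to both TP ranks of EP rank 0.
import torch

routing_map = torch.tensor([[True, False]])            # [num_local_tokens=1, num_experts=2]
expanded = (
    routing_map.reshape(1, 2, 1, 1)                    # [tokens, ep_size, 1, local_experts]
    .expand(-1, -1, 2, -1)                             # replicate across tp_size=2
    .reshape(1, 4, 1)                                  # [tokens, world_size, local_experts]
)
# expanded[0, :, 0] == tensor([True, True, False, False])
```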
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
self.hidden_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Initialize metadata
routing_map, probs = self._initialize_metadata(routing_map, probs)
self._comm_manager.setup_metadata(routing_map, probs)
hidden_states = self._comm_manager.dispatch(hidden_states)
global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(
hidden_states
)
tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
return global_input_tokens, tokens_per_expert
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
hidden_states = self._comm_manager.combine(hidden_states)
return hidden_states.view(self.hidden_shape), None
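A hedged end-to-end usage sketch of the flex dispatcher; the grouped expert module `experts` and its call signature are illustrative assumptions, not the exact MCore API:

```python
# Assumed round trip: dispatch, run local experts, combine.
dispatcher = MoEFlexTokenDispatcher(num_local_experts, local_expert_indices, config)
permuted, tokens_per_expert = dispatcher.token_permutation(hidden_states, probs, routing_map)
expert_output, _ = experts(permuted, tokens_per_expert)   # assumed grouped-expert forward
output, _ = dispatcher.token_unpermutation(expert_output)
```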
File mode changed from 100755 to 100644
......@@ -7,12 +7,16 @@ from typing import Union
import torch
from megatron.core import parallel_state
from megatron.core.models.common.embeddings import (
RotaryEmbedding,
YarnRotaryEmbedding,
_yarn_get_mscale,
apply_rotary_pos_emb,
)
from megatron.core.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.attention import Attention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
......@@ -50,11 +54,6 @@ class MultiLatentAttention(Attention):
attention_type: str,
cp_comm_type: str = None,
) -> None:
world_size = parallel_state.get_tensor_model_parallel_world_size()
assert (
world_size == 1
), "MLA is not supported with Tensor Parallelism yet, \
use Expert Parallelism and Pipeline Parallelism for better performance."
super().__init__(
config=config,
......@@ -68,19 +67,35 @@ class MultiLatentAttention(Attention):
self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim
# Overwrite the base class kv shape to support MLA inference
self.key_hidden_size = self.q_head_dim
self.val_hidden_size = self.config.v_head_dim
mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
self.rotary_pos_emb = YarnRotaryEmbedding(
self.config.qk_pos_emb_head_dim,
rotary_base=self.config.rotary_base,
scaling_factor=self.config.rotary_scaling_factor,
original_max_position_embeddings=self.config.max_position_embeddings,
beta_fast=self.config.beta_fast,
beta_slow=self.config.beta_slow,
mscale=self.config.mscale,
mscale_all_dim=self.config.mscale_all_dim,
)
if self.config.rope_type == "rope":
self.rotary_pos_emb = RotaryEmbedding(
self.config.qk_pos_emb_head_dim,
rotary_percent=self.config.rotary_percent,
rotary_base=self.config.rotary_base,
)
elif self.config.rope_type == "yarn":
self.rotary_pos_emb = YarnRotaryEmbedding(
self.config.qk_pos_emb_head_dim,
rotary_base=self.config.rotary_base,
scaling_factor=self.config.rotary_scaling_factor,
original_max_position_embeddings=self.config.max_position_embeddings,
beta_fast=self.config.beta_fast,
beta_slow=self.config.beta_slow,
mscale=self.config.mscale,
mscale_all_dim=self.config.mscale_all_dim,
)
else:
raise ValueError(
f"Unsupported RoPE type: {self.config.rope_type}, supported types are "
"'rope' and 'yarn'"
)
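For reference, a worked example of the softmax scale set earlier in `__init__` (`softmax_scale = mscale * mscale / sqrt(q_head_dim)`), with illustrative head dimensions:

```python
# With qk_head_dim=128 and qk_pos_emb_head_dim=64, q_head_dim=192; with a YaRN mscale of
# 1.0 the scale reduces to the standard 1/sqrt(d) attention scaling.
import math

qk_head_dim, qk_pos_emb_head_dim = 128, 64
q_head_dim = qk_head_dim + qk_pos_emb_head_dim      # 192
mscale = 1.0
softmax_scale = mscale * mscale / math.sqrt(q_head_dim)
print(round(softmax_scale, 4))                      # ~0.0722
```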
self.core_attention = build_module(
submodules.core_attention,
......@@ -120,6 +135,7 @@ class MultiLatentAttention(Attention):
attention_bias=None,
packed_seq_params=None,
position_ids=None,
sequence_len_offset=None,
):
"""Forward pass for multi-latent attention"""
assert rotary_pos_emb is None, "Rotary position embeddings should not be passed into MLA."
......@@ -230,9 +246,9 @@ class MLASelfAttention(MultiLatentAttention):
self.config.q_lora_rank,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
gather_output=False,
is_expert=False,
)
......@@ -254,9 +270,9 @@ class MLASelfAttention(MultiLatentAttention):
self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
gather_output=False,
is_expert=False,
)
......@@ -303,16 +319,19 @@ class MLASelfAttention(MultiLatentAttention):
assert (
hidden_states.ndim == 3
), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"
q_len, bsz, _ = hidden_states.size()
if self.config.q_lora_rank is not None:
q_compressed, _ = self.linear_q_down_proj(hidden_states)
q_compressed = self.q_layernorm(q_compressed)
q, _ = self.linear_q_up_proj(q_compressed)
q_compressed = gather_from_tensor_model_parallel_region(q_compressed)
if self.config.sequence_parallel:
q_compressed = scatter_to_sequence_parallel_region(q_compressed)
q, _ = self.linear_q_up_proj(self.q_layernorm(q_compressed))
else:
# hidden_states:[s, b, 2048], q: [s, b, n * 192]
q, _ = self.linear_q_proj(hidden_states)
q_len, bsz, _ = q.size()
# q: [s, b, n, 192]
q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)
......@@ -323,12 +342,15 @@ class MLASelfAttention(MultiLatentAttention):
# kv_combined: [s, b, 576]
kv_combined, _ = self.linear_kv_down_proj(hidden_states)
kv_combined = gather_from_tensor_model_parallel_region(kv_combined)
# kv_compressed:[s, b, 512], k_pos_emb: [s, b, 64]
kv_compressed, k_pos_emb = torch.split(
kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
)
if self.config.sequence_parallel:
kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)
# kv: [s, b, 2048]
kv, _ = self.linear_kv_up_proj(self.kv_layernorm(kv_compressed))
......@@ -343,18 +365,29 @@ class MLASelfAttention(MultiLatentAttention):
# k_no_pe: [s, b, n, 128], value: [s, b, n, 128]
k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)
# rotary_pos_emb:[s, b, 1, 64]
rotary_pos_emb = self.rotary_pos_emb(max_seq_len=self.config.max_position_embeddings)
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, None, hidden_states, self.config, packed_seq_params
)
if len(rotary_pos_emb) == 2:
mscale = rotary_pos_emb[1]
rotary_pos_emb = rotary_pos_emb[0]
# rotary_pos_emb:[s, b, 1, 64]
mscale = 1.0
if self.config.rope_type == "rope":
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
else:
rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)
if inference_params is not None:
# add offset to the sequence start for inference
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + q_len
rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
else:
# Shorten rotary_pos_emb to the sequence length when inference_params
# is not provided. This makes sure we can run forward directly with
# any sequence length. During training, the sequence length is always
# the full rotary_pos_emb length.
rotary_pos_emb = rotary_pos_emb[0:q_len]
# [s, b, 64] -> [s, b, 1, 64]
k_pos_emb = torch.unsqueeze(k_pos_emb, 2)
......@@ -377,7 +410,7 @@ class MLASelfAttention(MultiLatentAttention):
query = torch.cat([q_no_pe, q_pos_emb], dim=-1)
# key: [s, b, n, 192]
k_pos_emb = k_pos_emb.expand(-1, -1, self.config.num_attention_heads, -1)
k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
query = query.contiguous()
......