"src/vscode:/vscode.git/clone" did not exist on "3fef5d27d32ef2cec0871e6676f5c09c7e91fe02"
Unverified Commit e98afbe0 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support dispatching logical to physical experts (#6385)

parent 69af3ec3
......@@ -6,6 +6,7 @@ from torch.nn import Module
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
try:
......@@ -237,6 +238,9 @@ class EPMoE(torch.nn.Module):
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
......
......@@ -22,6 +22,10 @@ from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
......@@ -100,6 +104,7 @@ def grouped_topk(
n_share_experts_fusion: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
......@@ -140,6 +145,7 @@ def grouped_topk(
topk_weights = topk_weights / topk_weights_sum
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
return topk_weights, topk_ids
......@@ -155,6 +161,7 @@ def biased_grouped_topk_impl(
n_share_experts_fusion: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
......@@ -202,6 +209,7 @@ def biased_grouped_topk_impl(
topk_weights = topk_weights / topk_weights_sum
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
return topk_weights, topk_ids
......@@ -232,6 +240,7 @@ def biased_grouped_topk(
n_share_experts_fusion: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
assert (
routed_scaling_factor is not None
......@@ -252,6 +261,8 @@ def biased_grouped_topk(
n_share_experts_fusion,
routed_scaling_factor,
)
# TODO merge into kernel for this branch
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
# TODO will fuse this into kernel, thus use slow manual operation now
torch.compile(
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
......@@ -276,6 +287,7 @@ def biased_grouped_topk(
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
......@@ -292,6 +304,7 @@ def select_experts(
torch_native: bool = False,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# DeepSeek V2/V3/R1 series models use grouped_top_k
......@@ -309,6 +322,7 @@ def select_experts(
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
......@@ -322,11 +336,13 @@ def select_experts(
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
elif torch_native and custom_routing_function is None:
assert (
num_token_non_padded is None
), "num_token_non_padded is not yet supported in fused_topk_native"
assert expert_location_dispatch_info is None
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
......@@ -337,6 +353,7 @@ def select_experts(
assert (
num_token_non_padded is None
), "num_token_non_padded is not yet supported in fused_topk"
assert expert_location_dispatch_info is None
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
......@@ -347,6 +364,7 @@ def select_experts(
assert (
num_token_non_padded is None
), "num_token_non_padded is not yet supported in custom_routing_function"
assert expert_location_dispatch_info is None
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
......
......@@ -23,9 +23,10 @@ import torch
import torch.distributed
from sglang.srt.managers.expert_location import ExpertLocationMetadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable
from sglang.srt.utils import Withable, get_bool_env_var
logger = logging.getLogger(__name__)
......
......@@ -33,6 +33,7 @@ class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
# -------------------------------- properties ------------------------------------
......@@ -67,9 +68,11 @@ class ExpertLocationMetadata:
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
# TODO pr-chain: enable this later
# assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
# assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
num_layers_3, num_logical_experts_2 = (
self.logical_to_rank_dispatch_physical_map.shape
)
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------
......@@ -196,6 +199,13 @@ class ExpertLocationMetadata:
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
ep_rank=torch.distributed.get_rank(),
),
)
# -------------------------------- usage ------------------------------------
......@@ -262,6 +272,51 @@ def _pad_nested_array(arr, pad_value):
return padded
# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map: torch.Tensor,
logical_to_all_physical_map_num_valid: torch.Tensor,
num_gpus: int,
num_physical_experts: int,
ep_rank: int,
base_seed: int = 42,
):
device = logical_to_all_physical_map.device
num_local_physical_experts = num_physical_experts // num_gpus
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
g = torch.Generator(device=device)
g.manual_seed(base_seed + ep_rank)
output_shape = (num_layers, num_logical_experts)
chosen_index = (
torch.randint(
0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
)
% logical_to_all_physical_map_num_valid
)
logical_to_rank_dispatch_physical_map = torch.gather(
logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
).squeeze(-1)
assert logical_to_rank_dispatch_physical_map.shape == output_shape
for index in range(logical_to_all_physical_map_num_valid.max().item()):
partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
is_valid = partial_logical_to_all_physical_map != -1
is_same_gpu = (
partial_logical_to_all_physical_map // num_local_physical_experts
) == ep_rank
logical_to_rank_dispatch_physical_map = torch.where(
is_valid & is_same_gpu,
partial_logical_to_all_physical_map,
logical_to_rank_dispatch_physical_map,
)
assert torch.all(logical_to_rank_dispatch_physical_map != -1)
return logical_to_rank_dispatch_physical_map
@dataclass
class ModelConfigForExpertLocation:
num_layers: int
......
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
@dataclass
class ExpertLocationDispatchInfo:
ep_dispatch_algorithm: Literal["static", "random"]
# (num_logical_experts,)
partial_logical_to_rank_dispatch_physical_map: torch.Tensor
# (num_logical_experts, X)
partial_logical_to_all_physical_map: torch.Tensor
# (num_logical_experts,)
partial_logical_to_all_physical_map_num_valid: torch.Tensor
num_physical_experts: int
@classmethod
def init_new(cls, layer_id: int):
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
expert_location_metadata = get_global_expert_location_metadata()
if ep_dispatch_algorithm is None:
return None
return cls(
ep_dispatch_algorithm=ep_dispatch_algorithm,
partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
layer_id, :
],
partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
layer_id, :
],
partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
layer_id, :
],
num_physical_experts=expert_location_metadata.num_physical_experts,
)
def topk_ids_logical_to_physical(
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
if info is None:
return topk_ids
if info.ep_dispatch_algorithm == "static":
return _topk_ids_logical_to_physical_static(topk_ids, info)
if info.ep_dispatch_algorithm == "dynamic":
return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
raise NotImplementedError
def _topk_ids_logical_to_physical_static(
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
return info.partial_logical_to_rank_dispatch_physical_map[topk_ids]
def _topk_ids_logical_to_physical_dynamic(
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
topk_ids_original_shape = topk_ids.shape
device = topk_ids.device
topk_ids = topk_ids.flatten()
chosen_dispatch_index = (
torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
% info.partial_logical_to_all_physical_map_num_valid[topk_ids]
)
topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]
topk_ids = topk_ids.view(topk_ids_original_shape)
return topk_ids
......@@ -83,6 +83,7 @@ global_server_args_dict = {
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
......
......@@ -13,7 +13,6 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""
import collections
import datetime
import gc
import inspect
......@@ -196,6 +195,7 @@ class ModelRunner:
"deepep_config": server_args.deepep_config,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size,
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"torchao_config": server_args.torchao_config,
......
......@@ -80,6 +80,7 @@ from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -113,6 +114,7 @@ if _is_hip:
decode_attention_fwd_grouped_rope,
)
logger = logging.getLogger(__name__)
......
......@@ -170,6 +170,7 @@ class ServerArgs:
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
init_expert_location: str = "trivial"
expert_distribution_recorder_mode: Optional[
Literal["stat", "per_pass", "per_token"]
......@@ -1271,6 +1272,12 @@ class ServerArgs:
default="auto",
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
)
parser.add_argument(
"--ep-dispatch-algorithm",
type=str,
default=ServerArgs.ep_dispatch_algorithm,
help="The algorithm to choose ranks for redundant experts in expert parallel.",
)
parser.add_argument(
"--init-expert-location",
type=str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment