Unverified Commit bfc3b3f7 authored by Cheng Wan, committed by GitHub

[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)

parent da5bde4d
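
The diff below reworks execute_sbo so that, instead of handing topk_idx/topk_weights/forward_batch straight to the experts module, it goes through a dispatcher object: experts.dispatcher.dispatch(...), then experts.run_moe_core(...), then experts.dispatcher.combine(...). The following is a minimal, self-contained sketch of that three-step call flow for orientation only; the Toy* classes, their fields, and the placeholder expert math are illustrative assumptions, not the real SGLang implementations.

# Illustrative sketch only (not SGLang code): toy stand-ins that mirror the
# dispatcher-centric interface exercised by execute_sbo() in this commit.
# All class names, fields, and the placeholder expert math are assumptions.
from dataclasses import dataclass

import torch


@dataclass
class ToyDispatchOutput:
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor


class ToyDispatcher:
    # dispatch() routes tokens (plus their top-k routing info) to experts;
    # combine() gathers and merges the expert outputs back per token.
    def dispatch(self, hidden_states, topk_output):
        topk_ids, topk_weights = topk_output
        return ToyDispatchOutput(hidden_states, topk_ids, topk_weights)

    def combine(self, hidden_states, topk_ids, topk_weights, overlap_args=None):
        return hidden_states  # a real combine would all-to-all and weight-sum


class ToyMoE:
    # Stand-in for the experts module: owns a dispatcher and the core expert GEMMs.
    def __init__(self):
        self.dispatcher = ToyDispatcher()

    def run_moe_core(self, dispatch_output, down_gemm_overlap_args=None):
        return dispatch_output.hidden_states * 2.0  # placeholder expert math


def toy_flow(experts, hidden_states, topk_output):
    # Same three-step shape as the refactored execute_sbo(): dispatch -> core -> combine.
    out = experts.dispatcher.dispatch(hidden_states=hidden_states, topk_output=topk_output)
    h = experts.run_moe_core(out)
    return experts.dispatcher.combine(
        hidden_states=h, topk_ids=out.topk_ids, topk_weights=out.topk_weights
    )


if __name__ == "__main__":
    x = torch.randn(4, 8)
    topk = (torch.zeros(4, 2, dtype=torch.long), torch.full((4, 2), 0.5))
    print(toy_flow(ToyMoE(), x, topk).shape)  # torch.Size([4, 8])
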
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional
@@ -5,12 +21,12 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import get_moe_runner_backend
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import is_sbo_enabled
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_int_env_var
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 class SboFlags:
@@ -54,23 +70,22 @@ class DownGemmOverlapArgs:
 def execute_sbo(
     forward_shared_experts: Callable[[], Any],
-    experts: "DeepEPMoE",
+    experts: FusedMoE,
     hidden_states: torch.Tensor,
-    topk_idx: torch.Tensor,
-    topk_weights: torch.Tensor,
-    forward_batch: ForwardBatch,
-    alt_stream: Optional = None,
+    topk_output: TopKOutput,
+    alt_stream: Optional[torch.cuda.Stream] = None,
     disable_sbo: bool = False,
 ):
-    dispatch_output = experts.dispatch(
-        hidden_states, topk_idx, topk_weights, forward_batch
+    dispatch_output = experts.dispatcher.dispatch(
+        hidden_states=hidden_states, topk_output=topk_output
     )
     combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
         _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
     )
-    hidden_states = experts.moe_impl(
+    hidden_states = experts.run_moe_core(
         dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
     )
     if (e := meta_overlap_args.get("record_event_after_down")) is not None:
@@ -83,11 +98,10 @@ def execute_sbo(
     ):
         forward_shared_experts()
-    hidden_states = experts.combine(
-        hidden_states,
-        dispatch_output.topk_idx,
-        dispatch_output.topk_weights,
-        forward_batch,
+    hidden_states = experts.dispatcher.combine(
+        hidden_states=hidden_states,
+        topk_ids=dispatch_output.topk_ids,
+        topk_weights=dispatch_output.topk_weights,
         overlap_args=combine_overlap_args,
     )
@@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
     ):
         return None, None, {}
-    hidden_states = dispatch_output.hidden_states_fp8
-    if isinstance(hidden_states, tuple):
-        hidden_states = hidden_states[0]
+    hidden_states = dispatch_output.hidden_states
     num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
......
@@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)
         # Backup two fields, which will be modified in-place in `draft_forward`.
         output_cache_loc_backup = forward_batch.out_cache_loc
......
@@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)
         # Backup two fields, which will be modified in-place in `draft_forward`.
         output_cache_loc_backup = forward_batch.out_cache_loc
......
@@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:
     def combine_b(self, **kwargs):
         return self._execute("combine_b", **kwargs)
+    def set_quant_config(self, quant_config: dict):
+        for inner in self._inners:
+            inner.set_quant_config(quant_config)
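
For context, the new set_quant_config above is a plain fan-out: the wrapper forwards the same config dict to every inner dispatcher it manages. Below is a tiny self-contained sketch of that delegation pattern; the classes and the example config key are hypothetical stand-ins, not the real MaybeTboDeepEPDispatcher internals.

# Hypothetical stand-ins illustrating the fan-out delegation used above.
class _InnerDispatcher:
    def __init__(self):
        self.quant_config = {}

    def set_quant_config(self, quant_config: dict):
        self.quant_config = quant_config


class _WrapperDispatcher:
    # Holds several inner dispatchers and forwards the same call to each.
    def __init__(self, num_inner: int = 2):
        self._inners = [_InnerDispatcher() for _ in range(num_inner)]

    def set_quant_config(self, quant_config: dict):
        for inner in self._inners:
            inner.set_quant_config(quant_config)


w = _WrapperDispatcher()
w.set_quant_config({"example_key": True})  # "example_key" is a made-up key
assert all(i.quant_config == {"example_key": True} for i in w._inners)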