Unverified commit bfc3b3f7 authored by Cheng Wan, committed by GitHub

[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)

parent da5bde4d
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import annotations
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Optional
@@ -5,12 +21,12 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import get_moe_runner_backend
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import is_sbo_enabled
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_int_env_var

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

 class SboFlags:
@@ -54,23 +70,22 @@ class DownGemmOverlapArgs:
 def execute_sbo(
     forward_shared_experts: Callable[[], Any],
-    experts: "DeepEPMoE",
+    experts: FusedMoE,
     hidden_states: torch.Tensor,
-    topk_idx: torch.Tensor,
-    topk_weights: torch.Tensor,
-    forward_batch: ForwardBatch,
-    alt_stream: Optional = None,
+    topk_output: TopKOutput,
+    alt_stream: Optional[torch.cuda.Stream] = None,
     disable_sbo: bool = False,
 ):
-    dispatch_output = experts.dispatch(
-        hidden_states, topk_idx, topk_weights, forward_batch
+    dispatch_output = experts.dispatcher.dispatch(
+        hidden_states=hidden_states, topk_output=topk_output
     )
     combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
         _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
     )
-    hidden_states = experts.moe_impl(
+    hidden_states = experts.run_moe_core(
         dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
     )
     if (e := meta_overlap_args.get("record_event_after_down")) is not None:
@@ -83,11 +98,10 @@ def execute_sbo(
     ):
         forward_shared_experts()
-    hidden_states = experts.combine(
-        hidden_states,
-        dispatch_output.topk_idx,
-        dispatch_output.topk_weights,
-        forward_batch,
+    hidden_states = experts.dispatcher.combine(
+        hidden_states=hidden_states,
+        topk_ids=dispatch_output.topk_ids,
+        topk_weights=dispatch_output.topk_weights,
         overlap_args=combine_overlap_args,
     )
@@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
     ):
         return None, None, {}
-    hidden_states = dispatch_output.hidden_states_fp8
-    if isinstance(hidden_states, tuple):
-        hidden_states = hidden_states[0]
+    hidden_states = dispatch_output.hidden_states
     num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
...
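For readers following the refactor, the diff above routes dispatch/combine through the expert module's `dispatcher` attribute and collapses the old `topk_idx`/`topk_weights`/`forward_batch` arguments into a single `TopKOutput`. Below is a minimal sketch of that call shape using mock classes: only the attribute and method names (`dispatcher.dispatch`, `run_moe_core`, `dispatcher.combine`) and the keyword arguments come from the diff; the mock types and the pass-through math are illustrative assumptions, not the real sglang implementation.

```python
# Sketch of the refactored dispatcher interface shape, not the real sglang classes.
from dataclasses import dataclass

import torch


@dataclass
class MockTopKOutput:           # stand-in for sglang's TopKOutput
    topk_ids: torch.Tensor      # [num_tokens, top_k] expert indices (assumed field names)
    topk_weights: torch.Tensor  # [num_tokens, top_k] routing weights


@dataclass
class MockDispatchOutput:       # stand-in for the dispatcher's dispatch() result
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor


class MockDispatcher:
    def dispatch(self, *, hidden_states, topk_output):
        # A real dispatcher permutes tokens to their experts (possibly across
        # EP ranks); this mock just passes the data through.
        return MockDispatchOutput(
            hidden_states, topk_output.topk_ids, topk_output.topk_weights
        )

    def combine(self, *, hidden_states, topk_ids, topk_weights, overlap_args=None):
        # A real combine reduces per-expert outputs back into token order.
        return hidden_states


class MockExperts:
    dispatcher = MockDispatcher()

    def run_moe_core(self, dispatch_output, down_gemm_overlap_args=None):
        # Identity "experts" for the sketch; the real method runs the MoE GEMMs.
        return dispatch_output.hidden_states


experts = MockExperts()
x = torch.randn(4, 8)
topk = MockTopKOutput(torch.zeros(4, 2, dtype=torch.long), torch.ones(4, 2) / 2)

out = experts.dispatcher.dispatch(hidden_states=x, topk_output=topk)
y = experts.run_moe_core(out)
y = experts.dispatcher.combine(
    hidden_states=y, topk_ids=out.topk_ids, topk_weights=out.topk_weights
)
print(y.shape)  # torch.Size([4, 8])
```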
@@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)

         # Backup two fields, which will be modified in-place in `draft_forward`.
         output_cache_loc_backup = forward_batch.out_cache_loc
...
@@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)

         # Backup two fields, which will be modified in-place in `draft_forward`.
         output_cache_loc_backup = forward_batch.out_cache_loc
...
@@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:
     def combine_b(self, **kwargs):
         return self._execute("combine_b", **kwargs)
+
+    def set_quant_config(self, quant_config: dict):
+        for inner in self._inners:
+            inner.set_quant_config(quant_config)
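The new `set_quant_config` hook simply fans the quantization config out to every wrapped dispatcher. A minimal sketch of that pattern follows; only the method name and the `_inners` attribute come from the diff, while the wrapper/inner classes here are hypothetical stand-ins for illustration.

```python
# Sketch of the config fan-out pattern added to MaybeTboDeepEPDispatcher.
class InnerDispatcher:
    def __init__(self):
        self.quant_config = None

    def set_quant_config(self, quant_config: dict):
        self.quant_config = quant_config


class WrapperDispatcher:
    def __init__(self, num_inners: int = 2):
        # Placeholder for the wrapped per-micro-batch dispatchers.
        self._inners = [InnerDispatcher() for _ in range(num_inners)]

    def set_quant_config(self, quant_config: dict):
        # Forward the same config to each inner dispatcher.
        for inner in self._inners:
            inner.set_quant_config(quant_config)


wrapper = WrapperDispatcher()
wrapper.set_quant_config({"dtype": "fp8"})
assert all(inner.quant_config == {"dtype": "fp8"} for inner in wrapper._inners)
```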