Unverified Commit 64994980 authored by Cheng Wan, committed by GitHub

[10/N] MoE Refactor: reorganize deepgemm runner in DeepEPMoE (#12054)

parent 729b2429
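In brief, this change moves the DeepGEMM grouped-GEMM paths out of `DeepEPMoE` and into the `deep_gemm` MoE runner, and switches the expert/combine boundary to `CombineInput` objects: `run_moe_core` now returns a `DeepEPNormalCombineInput` or `DeepEPLLCombineInput`, and `DeepEPDispatcher.combine`/`combine_a` take that single object instead of separate tensors. A minimal sketch of the resulting call pattern (illustrative driver code, not part of the diff; it assumes an initialized `experts` layer and a `dispatch_output` produced by its dispatcher):

```python
# Sketch only: the call pattern after this refactor, based on the hunks below.
# `experts` (a DeepEPMoE instance) and `dispatch_output` are assumed to exist.
combine_input = experts.run_moe_core(dispatch_output)
# combine_input is a DeepEPNormalCombineInput or DeepEPLLCombineInput carrying
# hidden_states, topk_ids, topk_weights, and an optional overlap_args field.
final_hidden_states = experts.dispatcher.combine(combine_input=combine_input)
```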
from __future__ import annotations
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import torch
@@ -13,29 +13,23 @@ from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
-from sglang.srt.layers.moe.ep_moe.kernels import (
-ep_gather,
-ep_scatter,
-silu_and_mul_masked_post_quant_fwd,
-tma_align_input_scale,
-)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+DeepEPLLCombineInput,
+DeepEPNormalCombineInput,
+)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.fp8_kernel import (
-is_fp8_fnuz,
-sglang_per_token_group_quant_fp8,
-)
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
-from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
-DeepEPLLOutput,
+DeepEPLLDispatchOutput,
-DeepEPNormalOutput,
+DeepEPNormalDispatchOutput,
DispatchOutput,
)
@@ -45,7 +39,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not (_is_npu or _is_hip):
-from sgl_kernel import silu_and_mul
+pass
if _use_aiter:
from aiter import ActivationType, QuantType
@@ -90,6 +84,18 @@ class DeepEPMoE(FusedMoE):
routed_scaling_factor=routed_scaling_factor,
)
+if _use_aiter or _is_npu:
+self.deprecate_flag = False
+elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(
+quant_config, Fp8Config
+):
+self.deprecate_flag = True
+else:
+self.deprecate_flag = False
+if self.deprecate_flag:
+return
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.use_fp8_w8a8 = True
@@ -152,6 +158,14 @@ class DeepEPMoE(FusedMoE):
disable_sbo=False,
):
+if self.deprecate_flag:
+assert forward_shared_experts is None
+assert alt_stream is None
+return super().forward(
+hidden_states,
+topk_output,
+)
# We have to call SBO inside MoE to be compatible with hooks used in offloading
return single_batch_overlap.execute_sbo(
hidden_states=hidden_states,
@@ -178,37 +192,51 @@ class DeepEPMoE(FusedMoE):
dispatch_output: DispatchOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
):
+if self.deprecate_flag:
+assert down_gemm_overlap_args is None
+return super().run_moe_core(
+dispatch_output,
+)
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-return self.forward_aiter(dispatch_output)
+output = self.forward_aiter(dispatch_output)
-if _is_npu:
+elif _is_npu:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
-return self.forward_npu(dispatch_output)
+output = self.forward_npu(dispatch_output)
-if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
-return self.forward_cutlass_w4afp8(dispatch_output)
+output = self.forward_cutlass_w4afp8(dispatch_output)
-assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+else:
-return self.forward_deepgemm_contiguous(dispatch_output)
+assert False, "forward_deepgemm_contiguous is deprecated"
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
):
-return self.forward_flashinfer_cutedsl(
+output = self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
elif self.use_w4afp8:
-return self.forward_cutlass_w4afp8_masked(dispatch_output)
+output = self.forward_cutlass_w4afp8_masked(dispatch_output)
-assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+else:
-assert down_gemm_overlap_args is None
+assert False, "forward_deepgemm_masked is deprecated"
-return self.forward_deepgemm_masked(dispatch_output)
-else:
+combine_input_wrapper = (
-raise ValueError(
+DeepEPNormalCombineInput
-f"Dispatch output format {dispatch_output.format} is not supported"
+if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
-)
+else DeepEPLLCombineInput
+)
+return combine_input_wrapper(
+hidden_states=output,
+topk_ids=dispatch_output.topk_ids,
+topk_weights=dispatch_output.topk_weights,
+overlap_args=down_gemm_overlap_args,
+)
def combine(
self,
@@ -226,7 +254,7 @@ class DeepEPMoE(FusedMoE):
def forward_aiter(
self,
-dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
+dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
):
hidden_states, topk_ids, topk_weights = (
dispatch_output.hidden_states,
@@ -258,158 +286,9 @@ class DeepEPMoE(FusedMoE):
expert_mask=self.expert_mask,
)
def forward_deepgemm_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
(
hidden_states,
hidden_states_scale,
topk_ids,
topk_weights,
num_recv_tokens_per_expert,
) = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
if num_recv_tokens_per_expert is None:
return hidden_states.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
M, K = hidden_states.size()
N = self.w13_weight.size(1)
scale_block_size = 128
w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
)
hidden_states_shape = hidden_states.shape
hidden_states_device = hidden_states.device
hidden_states_dtype = hidden_states.dtype
input_tensor = [
torch.empty(
(all_tokens, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
),
(
# TODO check whether need `zeros`
torch.zeros(
(ceil_div(K // 128, 4), all_tokens),
device=hidden_states.device,
dtype=torch.int,
).transpose(0, 1)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else torch.empty(
(all_tokens, K // 128),
device=hidden_states.device,
dtype=torch.float32,
)
),
]
m_indices = torch.empty(
all_tokens, device=hidden_states.device, dtype=torch.int32
)
output_index = torch.empty_like(topk_ids)
if get_offloader().forbid_copy_engine_usage:
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
num_recv_tokens_per_expert
)
else:
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
hidden_states,
hidden_states_scale,
topk_ids,
num_recv_tokens_per_expert_gpu,
expert_start_loc,
input_tensor[0],
input_tensor[1],
m_indices,
output_index,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
dispose_tensor(hidden_states)
gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
input_tensor[1] = tma_align_input_scale(input_tensor[1])
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
input_tensor, w13_weight_fp8, gateup_output, m_indices
)
del input_tensor
down_input = torch.empty(
(
all_tokens,
N // 2,
),
device=gateup_output.device,
dtype=torch.bfloat16,
)
silu_and_mul(gateup_output.view(-1, N), down_input)
del gateup_output
down_output = torch.empty(
(all_tokens, K),
device=hidden_states_device,
dtype=torch.bfloat16,
)
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input,
scale_block_size,
column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del down_input
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = tma_align_input_scale(down_input_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
w2_weight_fp8,
down_output,
m_indices,
)
del down_input_fp8, down_input_scale
gather_out = torch.empty(
hidden_states_shape,
device=hidden_states_device,
dtype=torch.bfloat16,
)
ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
return gather_out
def forward_flashinfer_cutedsl(
self,
-dispatch_output: DeepEPLLOutput,
+dispatch_output: DeepEPLLDispatchOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
):
hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
@@ -427,7 +306,7 @@ class DeepEPMoE(FusedMoE):
def forward_cutlass_w4afp8(
self,
-dispatch_output: DeepEPNormalOutput,
+dispatch_output: DeepEPNormalDispatchOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
@@ -436,90 +315,9 @@ class DeepEPMoE(FusedMoE):
dispatch_output=dispatch_output,
)
def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLOutput,
):
hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
assert hidden_states_scale.dtype == torch.float32 or (
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
and hidden_states_scale.dtype == torch.int32
), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}, DEEPGEMM_SCALE_UE8M0: {deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0}"
# GroupGemm-0
num_groups, m, k = hidden_states.size()
n = self.w13_weight.size(1)
expected_m = min(expected_m, m)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(hidden_states, hidden_states_scale),
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
dispose_tensor(hidden_states)
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=gateup_output.device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=gateup_output.device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
return down_output
def forward_cutlass_w4afp8_masked(
self,
-dispatch_output: DeepEPNormalOutput,
+dispatch_output: DeepEPLLDispatchOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
@@ -533,7 +331,7 @@ class DeepEPMoE(FusedMoE):
def forward_npu(
self,
-dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
+dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
):
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
@@ -546,9 +344,9 @@ class DeepEPMoE(FusedMoE):
output_dtype = torch.bfloat16
group_list_type = 1
-def _forward_normal(dispatch_output: DeepEPNormalOutput):
+def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
if TYPE_CHECKING:
-assert isinstance(dispatch_output, DeepEPNormalOutput)
+assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
dispatch_output
)
@@ -618,9 +416,9 @@ class DeepEPMoE(FusedMoE):
return hidden_states
-def _forward_ll(dispatch_output: DeepEPLLOutput):
+def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
if TYPE_CHECKING:
-assert isinstance(dispatch_output, DeepEPLLOutput)
+assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
(
hidden_states,
hidden_states_scale,
@@ -731,12 +529,3 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
return FusedMoE
-def copy_list_to_gpu_no_ce(arr: List[int]):
-from sgl_kernel.elementwise import copy_to_gpu_no_ce
-tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
-tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
-copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
-return tensor_gpu
@@ -839,7 +839,7 @@ class FusedMoE(torch.nn.Module):
dispatch_output=dispatch_output,
**kwargs,
)
-final_hidden_states = self.dispatcher.combine(combine_input)
+final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
...
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional
import torch
+from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
@@ -15,14 +16,28 @@ from sglang.srt.layers.moe.moe_runner.base import (
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
-from sglang.srt.utils import dispose_tensor
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+DeepEPLLCombineInput,
+DeepEPLLDispatchOutput,
+DeepEPNormalCombineInput,
+DeepEPNormalDispatchOutput,
+)
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
+_is_hip = is_hip()
+_is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+if not (_is_npu or _is_hip):
+from sgl_kernel import silu_and_mul
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -40,13 +55,23 @@ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+def copy_list_to_gpu_no_ce(arr: List[int]):
+from sgl_kernel.elementwise import copy_to_gpu_no_ce
+tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+return tensor_gpu
@dataclass
class DeepGemmRunnerInput(RunnerInput):
hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor
-masked_m: torch.Tensor
-expected_m: int
use_masked_gemm: bool
+masked_m: Optional[torch.Tensor] = None
+expected_m: Optional[int] = None
+m_indices: Optional[torch.Tensor] = None
@property
def runner_backend(self) -> MoeRunnerBackend:
@@ -84,20 +109,100 @@ class DeepGemmRunnerCore(MoeRunnerCore):
running_state: dict,
) -> DeepGemmRunnerOutput:
-if runner_input.use_masked_gemm:
+if not runner_input.use_masked_gemm:
-hidden_states = self._run_masked_gemm(
+hidden_states = self._run_contiguous_gemm(
-runner_input,
+runner_input, quant_info, running_state
-quant_info,
-running_state,
)
else:
-hidden_states = self._run_contiguous_gemm(
+hidden_states = self._run_masked_gemm(
-runner_input,
+runner_input, quant_info, running_state
-quant_info,
-running_state,
)
return DeepGemmRunnerOutput(hidden_states=hidden_states)
def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale
all_tokens = running_state["all_tokens"]
hidden_states_device = running_state["hidden_states_device"]
hidden_states_dtype = running_state["hidden_states_dtype"]
hidden_states_shape = running_state["hidden_states_shape"]
m_indices = runner_input.m_indices
N = quant_info.w13_weight.size(1)
K = hidden_states_shape[1]
scale_block_size = 128
w13_weight_fp8 = (
quant_info.w13_weight,
quant_info.w13_scale,
)
w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)
gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
hidden_states_scale = tma_align_input_scale(hidden_states_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(hidden_states, hidden_states_scale),
w13_weight_fp8,
gateup_output,
m_indices,
)
dispose_tensor(hidden_states)
dispose_tensor(hidden_states_scale)
down_input = torch.empty(
(
all_tokens,
N // 2,
),
device=gateup_output.device,
dtype=torch.bfloat16,
)
silu_and_mul(gateup_output.view(-1, N), down_input)
del gateup_output
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input,
scale_block_size,
column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del down_input
down_output = torch.empty(
(all_tokens, K),
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = tma_align_input_scale(down_input_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
w2_weight_fp8,
down_output,
m_indices,
)
return down_output
def _run_masked_gemm(
self,
runner_input: DeepGemmRunnerInput,
@@ -149,6 +254,7 @@ class DeepGemmRunnerCore(MoeRunnerCore):
expected_m,
)
dispose_tensor(hidden_states)
+dispose_tensor(hidden_states_scale)
# Act
down_input = torch.empty(
@@ -198,18 +304,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
masked_m,
expected_m,
)
-del down_input
return down_output
-def _run_contiguous_gemm(
-self,
-runner_input: DeepGemmRunnerInput,
-quant_info: DeepGemmMoeQuantInfo,
-running_state: dict,
-) -> torch.Tensor:
-pass
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@@ -222,6 +319,7 @@ def pre_permute_standard_to_deep_gemm(
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
hidden_states, topk_output = dispatch_output
@@ -257,9 +355,9 @@ def pre_permute_standard_to_deep_gemm(
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
+use_masked_gemm=True,
masked_m=masked_m,
expected_m=expected_m,
-use_masked_gemm=True,
)
@@ -302,3 +400,170 @@ def post_permute_deep_gemm_to_standard(
return StandardCombineInput(
hidden_states=output,
)
@register_pre_permute("deepep_ll", "deep_gemm")
def pre_permute_deepep_ll_to_deep_gemm(
dispatch_output: DeepEPLLDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
hidden_states, hidden_states_scale, topk_ids, topk_weights, masked_m, expected_m = (
dispatch_output
)
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
running_state["hidden_states_shape"] = hidden_states.shape
running_state["hidden_states_dtype"] = hidden_states.dtype
running_state["hidden_states_device"] = hidden_states.device
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
use_masked_gemm=True,
masked_m=masked_m,
expected_m=expected_m,
)
@register_post_permute("deep_gemm", "deepep_ll")
def post_permute_deep_gemm_to_deepep_ll(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepEPLLCombineInput:
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput
return DeepEPLLCombineInput(
hidden_states=runner_output.hidden_states,
topk_ids=running_state["topk_ids"],
topk_weights=running_state["topk_weights"],
)
@register_pre_permute("deepep_normal", "deep_gemm")
def pre_permute_deepep_normal_to_deep_gemm(
dispatch_output: DeepEPNormalDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter
(
hidden_states,
hidden_states_scale,
topk_ids,
topk_weights,
num_recv_tokens_per_expert,
) = dispatch_output
assert runner_config.activation == "silu"
all_tokens = sum(num_recv_tokens_per_expert)
running_state["all_tokens"] = all_tokens
K = hidden_states.shape[1]
hidden_states_shape = hidden_states.shape
hidden_states_device = hidden_states.device
hidden_states_dtype = hidden_states.dtype
running_state["hidden_states_shape"] = hidden_states_shape
running_state["hidden_states_device"] = hidden_states_device
running_state["hidden_states_dtype"] = hidden_states_dtype
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
input_tensor = torch.empty(
(all_tokens, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
# TODO check whether need `zeros`
input_tensor_scale = torch.zeros(
(ceil_div(K // 128, 4), all_tokens),
device=hidden_states.device,
dtype=torch.int,
).transpose(0, 1)
else:
input_tensor_scale = torch.empty(
(all_tokens, K // 128),
device=hidden_states.device,
dtype=torch.float32,
)
m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
output_index = torch.empty_like(topk_ids)
if get_offloader().forbid_copy_engine_usage:
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
num_recv_tokens_per_expert
)
else:
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
hidden_states,
hidden_states_scale,
topk_ids,
num_recv_tokens_per_expert_gpu,
expert_start_loc,
input_tensor,
input_tensor_scale,
m_indices,
output_index,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
dispose_tensor(hidden_states)
dispose_tensor(hidden_states_scale)
running_state["output_index"] = output_index
return DeepGemmRunnerInput(
hidden_states=input_tensor,
hidden_states_scale=input_tensor_scale,
use_masked_gemm=False,
m_indices=m_indices,
)
@register_post_permute("deep_gemm", "deepep_normal")
def post_permute_deep_gemm_to_deepep_normal(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepEPNormalCombineInput:
from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput
hidden_states = runner_output.hidden_states
topk_ids = running_state["topk_ids"]
topk_weights = running_state["topk_weights"]
output_index = running_state["output_index"]
gather_out = torch.empty(
running_state["hidden_states_shape"],
device=running_state["hidden_states_device"],
dtype=torch.bfloat16,
)
ep_gather(hidden_states, topk_ids, topk_weights, output_index, gather_out)
return DeepEPNormalCombineInput(
hidden_states=gather_out,
topk_ids=running_state["topk_ids"],
topk_weights=running_state["topk_weights"],
)
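The pre-permute hooks registered above decide which DeepGEMM path runs: the low-latency (`deepep_ll`) hook fills `masked_m`/`expected_m` and sets `use_masked_gemm=True`, while the normal (`deepep_normal`) hook scatters tokens into one contiguous buffer and passes `m_indices` with `use_masked_gemm=False`. A condensed, illustrative sketch of the two `DeepGemmRunnerInput` layouts follows; the module path and helper names are assumptions, not taken verbatim from the diff:

```python
# Illustrative sketch, not verbatim: the two DeepGemmRunnerInput layouts that
# reach DeepGemmRunnerCore. The import path is assumed from this file's location.
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerInput

def make_ll_input(hidden_states, hidden_states_scale, masked_m, expected_m):
    # deepep_ll: per-expert masked layout, consumed by the masked grouped GEMM
    return DeepGemmRunnerInput(
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        use_masked_gemm=True,
        masked_m=masked_m,
        expected_m=expected_m,
    )

def make_normal_input(scattered_states, scattered_scale, m_indices):
    # deepep_normal: tokens ep_scatter-ed into a contiguous (all_tokens, K)
    # buffer, consumed by the contiguous grouped GEMM via m_indices
    return DeepGemmRunnerInput(
        hidden_states=scattered_states,
        hidden_states_scale=scattered_scale,
        use_masked_gemm=False,
        m_indices=m_indices,
    )
```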
@@ -12,9 +12,9 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLCombineInput,
-DeepEPLLOutput,
+DeepEPLLDispatchOutput,
DeepEPNormalCombineInput,
-DeepEPNormalOutput,
+DeepEPNormalDispatchOutput,
)
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
MooncakeCombineInput,
@@ -44,8 +44,8 @@ __all__ = [
"StandardCombineInput",
"DeepEPConfig",
"DeepEPDispatcher",
-"DeepEPNormalOutput",
+"DeepEPNormalDispatchOutput",
-"DeepEPLLOutput",
+"DeepEPLLDispatchOutput",
"DeepEPLLCombineInput",
"DeepEPNormalCombineInput",
]
@@ -9,9 +9,9 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLCombineInput,
-DeepEPLLOutput,
+DeepEPLLDispatchOutput,
DeepEPNormalCombineInput,
-DeepEPNormalOutput,
+DeepEPNormalDispatchOutput,
StandardCombineInput,
StandardDispatchOutput,
)
@@ -37,19 +37,19 @@ class DispatchOutputChecker:
@staticmethod
def format_is_deepep_normal(
dispatch_output: DispatchOutput,
-) -> TypeGuard[DeepEPNormalOutput]:
+) -> TypeGuard[DeepEPNormalDispatchOutput]:
return dispatch_output.format.is_deepep_normal()
@staticmethod
def format_is_deepep_ll(
dispatch_output: DispatchOutput,
-) -> TypeGuard[DeepEPLLOutput]:
+) -> TypeGuard[DeepEPLLDispatchOutput]:
return dispatch_output.format.is_deepep_ll()
@staticmethod
def format_is_deepep(
dispatch_output: DispatchOutput,
-) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
+) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
return dispatch_output.format.is_deepep()
...
@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
logger = logging.getLogger(__name__)
-class DeepEPNormalOutput(NamedTuple):
+class DeepEPNormalDispatchOutput(NamedTuple):
"""DeepEP normal dispatch output."""
hidden_states: torch.Tensor
@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
return DispatchOutputFormat.DEEPEP_NORMAL
-class DeepEPLLOutput(NamedTuple):
+class DeepEPLLDispatchOutput(NamedTuple):
"""DeepEP low latency dispatch output."""
hidden_states: torch.Tensor
@@ -87,14 +87,17 @@ class DeepEPLLOutput(NamedTuple):
return DispatchOutputFormat.DEEPEP_LL
-assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)
class DeepEPNormalCombineInput(NamedTuple):
"""DeepEP normal combine input."""
-pass
+hidden_states: torch.Tensor
+topk_ids: torch.Tensor
+topk_weights: torch.Tensor
+overlap_args: Optional[CombineOverlapArgs] = None
@property
def format(self) -> CombineInputFormat:
@@ -104,7 +107,10 @@ class DeepEPNormalCombineInput(NamedTuple):
class DeepEPLLCombineInput(NamedTuple):
"""DeepEP low latency combine input."""
-pass
+hidden_states: torch.Tensor
+topk_ids: torch.Tensor
+topk_weights: torch.Tensor
+overlap_args: Optional[CombineOverlapArgs] = None
@property
def format(self) -> CombineInputFormat:
@@ -383,7 +389,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
else:
hidden_states_scale = None
-return DeepEPNormalOutput(
+return DeepEPNormalDispatchOutput(
hidden_states,
hidden_states_scale,
topk_ids,
@@ -562,7 +568,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
else:
hidden_states_scale = None
-deepep_output = DeepEPLLOutput(
+deepep_output = DeepEPLLDispatchOutput(
hidden_states,
hidden_states_scale,
topk_ids,
@@ -756,18 +762,16 @@ class DeepEPDispatcher(BaseDispatcher):
del self._dispatch_intermediate_state
return self._get_impl().dispatch_b(*inner_state)
-def combine(self, *args, **kwargs) -> Tuple:
+def combine(self, combine_input: CombineInput) -> Tuple:
-self.combine_a(*args, **kwargs)
+self.combine_a(combine_input)
ret = self.combine_b()
return ret
def combine_a(
self,
-hidden_states: torch.Tensor,
+combine_input: CombineInput,
-topk_ids: torch.Tensor,
-topk_weights: torch.Tensor,
-overlap_args: Optional["CombineOverlapArgs"] = None,
):
+hidden_states, topk_ids, topk_weights, overlap_args = combine_input
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl().combine_a(
hidden_states=hidden_states,
...
@@ -984,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
-topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
-topk_weights, topk_ids, _ = topk_output
+topk_weights, topk_ids, _ = dispatch_output.topk_output
x, topk_weights = apply_topk_weights_cpu(
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
@@ -1017,7 +1016,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
ret = self.maybe_apply_hip_fused_experts(
layer,
x,
-topk_output,
+dispatch_output.topk_output,
moe_runner_config.activation,
moe_runner_config.no_combine,
)
@@ -1027,7 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self._should_use_cutlass_fused_experts():
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
-topk_weights, topk_ids, _ = topk_output
+topk_weights, topk_ids, _ = dispatch_output.topk_output
output = cutlass_fused_experts_fp8(
x,
layer.w13_weight.transpose(1, 2),
...
@@ -23,8 +23,8 @@ if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
-DeepEPLLOutput,
+DeepEPLLDispatchOutput,
-DeepEPNormalOutput,
+DeepEPNormalDispatchOutput,
StandardDispatchOutput,
)
@@ -332,7 +332,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply_deepep_ll(
self,
layer: DeepEPMoE,
-dispatch_output: DeepEPLLOutput,
+dispatch_output: DeepEPLLDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
@@ -367,7 +367,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply_deepep_normal(
self,
layer: DeepEPMoE,
-dispatch_output: DeepEPNormalOutput,
+dispatch_output: DeepEPNormalDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.cutlass_w4a8_moe import (
cutlass_w4a8_moe_deepep_normal,
...
@@ -1005,16 +1005,14 @@ class DeepseekV2MoE(nn.Module):
)
def op_experts(self, state):
-state.hidden_states_experts_output = self.experts.run_moe_core(
+state.combine_input = self.experts.run_moe_core(
dispatch_output=state.dispatch_output,
)
def op_combine_a(self, state):
if self.ep_size > 1:
self.experts.dispatcher.combine_a(
-hidden_states=state.pop("hidden_states_experts_output"),
+combine_input=state.pop("combine_input"),
-topk_ids=state.dispatch_output.topk_ids,
-topk_weights=state.dispatch_output.topk_weights,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
...
@@ -241,16 +241,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
)
def op_experts(self, state):
-state.hidden_states_experts_output = self.experts.run_moe_core(
+state.combine_input = self.experts.run_moe_core(
dispatch_output=state.dispatch_output,
)
def op_combine_a(self, state):
if self.ep_size > 1:
self.experts.dispatcher.combine_a(
-hidden_states=state.pop("hidden_states_experts_output"),
+combine_input=state.pop("combine_input"),
-topk_ids=state.dispatch_output.topk_ids,
-topk_weights=state.dispatch_output.topk_weights,
tbo_subbatch_index=state.get("tbo_subbatch_index"),
)
state.pop("dispatch_output")
...
@@ -85,7 +85,7 @@ def execute_sbo(
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
)
-hidden_states = experts.run_moe_core(
+combine_input = experts.run_moe_core(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
@@ -98,12 +98,7 @@ def execute_sbo(
):
forward_shared_experts()
-hidden_states = experts.dispatcher.combine(
+hidden_states = experts.dispatcher.combine(combine_input=combine_input)
-hidden_states=hidden_states,
-topk_ids=dispatch_output.topk_ids,
-topk_weights=dispatch_output.topk_weights,
-overlap_args=combine_overlap_args,
-)
return hidden_states
...