Unverified Commit 64994980 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[10/N] MoE Refactor: reorganize deepgemm runner in DeepEPMoE (#12054)

parent 729b2429
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import torch
......@@ -13,29 +13,23 @@ from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPLLCombineInput,
DeepEPNormalCombineInput,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.utils.offloader import get_offloader
from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DeepEPLLDispatchOutput,
DeepEPNormalDispatchOutput,
DispatchOutput,
)
......@@ -45,7 +39,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
pass
if _use_aiter:
from aiter import ActivationType, QuantType
......@@ -90,6 +84,18 @@ class DeepEPMoE(FusedMoE):
routed_scaling_factor=routed_scaling_factor,
)
if _use_aiter or _is_npu:
self.deprecate_flag = False
elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(
quant_config, Fp8Config
):
self.deprecate_flag = True
else:
self.deprecate_flag = False
if self.deprecate_flag:
return
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.use_fp8_w8a8 = True
......@@ -152,6 +158,14 @@ class DeepEPMoE(FusedMoE):
disable_sbo=False,
):
if self.deprecate_flag:
assert forward_shared_experts is None
assert alt_stream is None
return super().forward(
hidden_states,
topk_output,
)
# We have to call SBO inside MoE to be compatible with hooks used in offloading
return single_batch_overlap.execute_sbo(
hidden_states=hidden_states,
......@@ -178,37 +192,51 @@ class DeepEPMoE(FusedMoE):
dispatch_output: DispatchOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
):
if self.deprecate_flag:
assert down_gemm_overlap_args is None
return super().run_moe_core(
dispatch_output,
)
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output)
if _is_npu:
output = self.forward_aiter(dispatch_output)
elif _is_npu:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
output = self.forward_npu(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
return self.forward_cutlass_w4afp8(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
output = self.forward_cutlass_w4afp8(dispatch_output)
else:
assert False, "forward_deepgemm_contiguous is deprecated"
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
):
return self.forward_flashinfer_cutedsl(
output = self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
elif self.use_w4afp8:
return self.forward_cutlass_w4afp8_masked(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
assert down_gemm_overlap_args is None
return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(
f"Dispatch output format {dispatch_output.format} is not supported"
)
output = self.forward_cutlass_w4afp8_masked(dispatch_output)
else:
assert False, "forward_deepgemm_masked is deprecated"
combine_input_wrapper = (
DeepEPNormalCombineInput
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else DeepEPLLCombineInput
)
return combine_input_wrapper(
hidden_states=output,
topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights,
overlap_args=down_gemm_overlap_args,
)
def combine(
self,
......@@ -226,7 +254,7 @@ class DeepEPMoE(FusedMoE):
def forward_aiter(
self,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
):
hidden_states, topk_ids, topk_weights = (
dispatch_output.hidden_states,
......@@ -258,158 +286,9 @@ class DeepEPMoE(FusedMoE):
expert_mask=self.expert_mask,
)
def forward_deepgemm_contiguous(
    self,
    dispatch_output: DeepEPNormalOutput,
):
    """Run the MoE experts with DeepGEMM contiguous (grouped) fp8 GEMMs.

    Pipeline: scatter dispatched tokens into a per-expert contiguous
    layout -> grouped GEMM (w13) -> silu_and_mul activation -> per-token
    fp8 re-quantization -> grouped GEMM (w2) -> gather rows back into the
    original token order.

    Args:
        dispatch_output: DeepEP "normal" dispatch result; unpacked as
            (hidden_states, hidden_states_scale, topk_ids, topk_weights,
            num_recv_tokens_per_expert).

    Returns:
        A bf16 tensor with the same shape as the dispatched hidden_states.
    """
    (
        hidden_states,
        hidden_states_scale,
        topk_ids,
        topk_weights,
        num_recv_tokens_per_expert,
    ) = dispatch_output
    assert self.quant_method is not None
    assert self.moe_runner_config.activation == "silu"
    # No per-expert token counts -> nothing was routed here; return the
    # (empty) dispatch buffer cast to bf16.
    if num_recv_tokens_per_expert is None:
        return hidden_states.bfloat16()
    all_tokens = sum(num_recv_tokens_per_expert)
    if all_tokens <= 0:
        return hidden_states.bfloat16()
    M, K = hidden_states.size()
    N = self.w13_weight.size(1)
    scale_block_size = 128

    # fp8 (weight, scale) pairs; block-quantized checkpoints keep the
    # scale under *_scale_inv instead of *_scale.
    w13_weight_fp8 = (
        self.w13_weight,
        (
            self.w13_weight_scale_inv
            if self.use_block_quant
            else self.w13_weight_scale
        ),
    )
    w2_weight_fp8 = (
        self.w2_weight,
        (
            self.w2_weight_scale_inv
            if self.use_block_quant
            else self.w2_weight_scale
        ),
    )

    # Remember shape/device/dtype before the input tensor is disposed.
    hidden_states_shape = hidden_states.shape
    hidden_states_device = hidden_states.device
    hidden_states_dtype = hidden_states.dtype

    # Scatter destination: [values, scales]. Under UE8M0 the scales are
    # packed int32 (4 e8m0 bytes per int) allocated K-major then viewed
    # token-major via transpose.
    input_tensor = [
        torch.empty(
            (all_tokens, K),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        ),
        (
            # TODO check whether need `zeros`
            torch.zeros(
                (ceil_div(K // 128, 4), all_tokens),
                device=hidden_states.device,
                dtype=torch.int,
            ).transpose(0, 1)
            if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            else torch.empty(
                (all_tokens, K // 128),
                device=hidden_states.device,
                dtype=torch.float32,
            )
        ),
    ]
    # Expert index of every scattered row (consumed by the grouped GEMM).
    m_indices = torch.empty(
        all_tokens, device=hidden_states.device, dtype=torch.int32
    )
    # Per (token, expert-slot) row index into the scattered layout; used
    # by ep_gather to un-permute the results.
    output_index = torch.empty_like(topk_ids)

    if get_offloader().forbid_copy_engine_usage:
        # Offloader reserves the copy engine; use the compute-stream copy.
        num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
            num_recv_tokens_per_expert
        )
    else:
        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

    # Permute tokens into per-expert contiguous rows.
    ep_scatter(
        hidden_states,
        hidden_states_scale,
        topk_ids,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        input_tensor[0],
        input_tensor[1],
        m_indices,
        output_index,
        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
    # Free the dispatch buffer early to reduce peak memory.
    dispose_tensor(hidden_states)

    # GroupGemm-0: scattered input @ w13 -> gate/up projections (bf16).
    gateup_output = torch.empty(
        (all_tokens, N),
        device=hidden_states_device,
        dtype=torch.bfloat16,
    )
    if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
        # fp32 scales must be TMA-aligned before the DeepGEMM kernel.
        input_tensor[1] = tma_align_input_scale(input_tensor[1])
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
        input_tensor, w13_weight_fp8, gateup_output, m_indices
    )
    del input_tensor

    # Activation: silu_and_mul halves the feature dimension (N -> N // 2).
    down_input = torch.empty(
        (
            all_tokens,
            N // 2,
        ),
        device=gateup_output.device,
        dtype=torch.bfloat16,
    )
    silu_and_mul(gateup_output.view(-1, N), down_input)
    del gateup_output

    # GroupGemm-1: re-quantize the activation to fp8, then @ w2 -> bf16.
    down_output = torch.empty(
        (all_tokens, K),
        device=hidden_states_device,
        dtype=torch.bfloat16,
    )
    down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
        down_input,
        scale_block_size,
        column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
    del down_input
    if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
        down_input_scale = tma_align_input_scale(down_input_scale)
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
        (down_input_fp8, down_input_scale),
        w2_weight_fp8,
        down_output,
        m_indices,
    )
    del down_input_fp8, down_input_scale

    # Gather: un-permute expert outputs back into the original token
    # order (topk_weights is handed to ep_gather for the reduction).
    gather_out = torch.empty(
        hidden_states_shape,
        device=hidden_states_device,
        dtype=torch.bfloat16,
    )
    ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
    return gather_out
def forward_flashinfer_cutedsl(
self,
dispatch_output: DeepEPLLOutput,
dispatch_output: DeepEPLLDispatchOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
):
hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
......@@ -427,7 +306,7 @@ class DeepEPMoE(FusedMoE):
def forward_cutlass_w4afp8(
self,
dispatch_output: DeepEPNormalOutput,
dispatch_output: DeepEPNormalDispatchOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
......@@ -436,90 +315,9 @@ class DeepEPMoE(FusedMoE):
dispatch_output=dispatch_output,
)
def forward_deepgemm_masked(
    self,
    dispatch_output: DeepEPLLOutput,
):
    """Run the MoE experts with DeepGEMM masked (grouped) fp8 GEMMs.

    Operates on the low-latency dispatch layout: a fixed
    (num_groups, m, k) buffer per expert group, with `masked_m` giving
    the number of valid rows per group, so no scatter/gather is needed.

    Args:
        dispatch_output: DeepEP low-latency dispatch result; unpacked as
            (hidden_states, hidden_states_scale, _, _, masked_m,
            expected_m).

    Returns:
        A bf16 tensor of shape (num_groups, m, w2_out_dim).
    """
    hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
    assert self.quant_method is not None
    assert self.moe_runner_config.activation == "silu"
    # Scales are fp32, except packed int32 when UE8M0 scaling is enabled.
    assert hidden_states_scale.dtype == torch.float32 or (
        deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
        and hidden_states_scale.dtype == torch.int32
    ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}, DEEPGEMM_SCALE_UE8M0: {deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0}"

    # GroupGemm-0: input @ w13 -> gate/up projections (bf16 output).
    num_groups, m, k = hidden_states.size()
    n = self.w13_weight.size(1)
    # expected_m is a hint for the kernel; it can never exceed the
    # per-group row capacity m.
    expected_m = min(expected_m, m)
    gateup_output = torch.empty(
        (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
    )
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
        (hidden_states, hidden_states_scale),
        self.w13_weight_fp8,
        gateup_output,
        masked_m,
        expected_m,
    )
    # Free the dispatch buffer early to reduce peak memory.
    dispose_tensor(hidden_states)

    # Act: fused silu_and_mul + per-group fp8 quantization; the feature
    # dimension is halved (n -> n // 2).
    down_input = torch.empty(
        (
            gateup_output.shape[0],
            gateup_output.shape[1],
            gateup_output.shape[2] // 2,
        ),
        device=gateup_output.device,
        dtype=self.fp8_dtype,
    )
    scale_block_size = 128
    down_input_scale = torch.empty(
        (
            gateup_output.shape[0],
            gateup_output.shape[1],
            gateup_output.shape[2] // 2 // scale_block_size,
        ),
        device=gateup_output.device,
        dtype=torch.float32,
    )
    silu_and_mul_masked_post_quant_fwd(
        gateup_output,
        down_input,
        down_input_scale,
        scale_block_size,
        masked_m,
        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
    del gateup_output

    # GroupGemm-1: activation @ w2 -> bf16 output. Non-UE8M0 scales need
    # MN-major TMA alignment before the kernel.
    n = self.w2_weight.size(1)
    down_input_fp8 = (
        down_input,
        (
            down_input_scale
            if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
        ),
    )
    down_output = torch.empty(
        (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
    )
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
        down_input_fp8,
        self.w2_weight_fp8,
        down_output,
        masked_m,
        expected_m,
    )
    return down_output
def forward_cutlass_w4afp8_masked(
self,
dispatch_output: DeepEPNormalOutput,
dispatch_output: DeepEPLLDispatchOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
......@@ -533,7 +331,7 @@ class DeepEPMoE(FusedMoE):
def forward_npu(
self,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
):
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
......@@ -546,9 +344,9 @@ class DeepEPMoE(FusedMoE):
output_dtype = torch.bfloat16
group_list_type = 1
def _forward_normal(dispatch_output: DeepEPNormalOutput):
def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPNormalOutput)
assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
dispatch_output
)
......@@ -618,9 +416,9 @@ class DeepEPMoE(FusedMoE):
return hidden_states
def _forward_ll(dispatch_output: DeepEPLLOutput):
def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPLLOutput)
assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
(
hidden_states,
hidden_states_scale,
......@@ -731,12 +529,3 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
return FusedMoE
def copy_list_to_gpu_no_ce(arr: List[int]):
    """Move a list of ints to the GPU without engaging the copy engine.

    Used when the offloader reserves the copy engine for itself; the
    transfer goes through the sgl_kernel no-CE copy path instead.
    """
    from sgl_kernel.elementwise import copy_to_gpu_no_ce

    host_tensor = torch.tensor(arr, dtype=torch.int32, device="cpu")
    device_tensor = torch.empty_like(host_tensor, device="cuda")
    copy_to_gpu_no_ce(host_tensor, device_tensor)
    return device_tensor
......@@ -839,7 +839,7 @@ class FusedMoE(torch.nn.Module):
dispatch_output=dispatch_output,
**kwargs,
)
final_hidden_states = self.dispatcher.combine(combine_input)
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
......
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
......@@ -15,14 +16,28 @@ from sglang.srt.layers.moe.moe_runner.base import (
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPLLCombineInput,
DeepEPLLDispatchOutput,
DeepEPNormalCombineInput,
DeepEPNormalDispatchOutput,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
......@@ -40,13 +55,23 @@ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
def copy_list_to_gpu_no_ce(arr: List[int]):
    """Upload a list of ints to the GPU while bypassing the copy engine.

    The offloader may forbid copy-engine usage; this routes the host ->
    device transfer through sgl_kernel's dedicated no-CE kernel.
    """
    from sgl_kernel.elementwise import copy_to_gpu_no_ce

    src = torch.tensor(arr, dtype=torch.int32, device="cpu")
    dst = torch.empty_like(src, device="cuda")
    copy_to_gpu_no_ce(src, dst)
    return dst
@dataclass
class DeepGemmRunnerInput(RunnerInput):
hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor
masked_m: torch.Tensor
expected_m: int
use_masked_gemm: bool
masked_m: Optional[torch.Tensor] = None
expected_m: Optional[int] = None
m_indices: Optional[torch.Tensor] = None
@property
def runner_backend(self) -> MoeRunnerBackend:
......@@ -84,20 +109,100 @@ class DeepGemmRunnerCore(MoeRunnerCore):
running_state: dict,
) -> DeepGemmRunnerOutput:
if runner_input.use_masked_gemm:
hidden_states = self._run_masked_gemm(
runner_input,
quant_info,
running_state,
if not runner_input.use_masked_gemm:
hidden_states = self._run_contiguous_gemm(
runner_input, quant_info, running_state
)
else:
hidden_states = self._run_contiguous_gemm(
runner_input,
quant_info,
running_state,
hidden_states = self._run_masked_gemm(
runner_input, quant_info, running_state
)
return DeepGemmRunnerOutput(hidden_states=hidden_states)
def _run_contiguous_gemm(
    self,
    runner_input: DeepGemmRunnerInput,
    quant_info: DeepGemmMoeQuantInfo,
    running_state: dict,
) -> torch.Tensor:
    """Execute the contiguous (grouped) DeepGEMM expert pipeline.

    Expects the input already scattered into a per-expert contiguous
    layout (see the deepep_normal pre-permute): grouped GEMM (w13) ->
    silu_and_mul -> fp8 re-quantization -> grouped GEMM (w2).

    Args:
        runner_input: scattered hidden states, their scales, and the
            per-row expert indices (m_indices).
        quant_info: fp8 weights and scales for w13/w2.
        running_state: shared dict filled by pre-permute; reads
            all_tokens and hidden_states shape/device/dtype.

    Returns:
        bf16 tensor of shape (all_tokens, K), still in scattered order.
    """
    from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )

    hidden_states = runner_input.hidden_states
    hidden_states_scale = runner_input.hidden_states_scale
    all_tokens = running_state["all_tokens"]
    hidden_states_device = running_state["hidden_states_device"]
    hidden_states_dtype = running_state["hidden_states_dtype"]
    hidden_states_shape = running_state["hidden_states_shape"]
    m_indices = runner_input.m_indices

    N = quant_info.w13_weight.size(1)
    K = hidden_states_shape[1]
    scale_block_size = 128

    # fp8 (weight, scale) pairs consumed by the grouped GEMM kernels.
    w13_weight_fp8 = (
        quant_info.w13_weight,
        quant_info.w13_scale,
    )
    w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)

    # GroupGemm-0: scattered input @ w13 -> gate/up projections (bf16).
    gateup_output = torch.empty(
        (all_tokens, N),
        device=hidden_states_device,
        dtype=torch.bfloat16,
    )
    if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
        # fp32 scales must be TMA-aligned before the DeepGEMM kernel.
        hidden_states_scale = tma_align_input_scale(hidden_states_scale)
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
        (hidden_states, hidden_states_scale),
        w13_weight_fp8,
        gateup_output,
        m_indices,
    )
    # Release the inputs early to reduce peak memory.
    dispose_tensor(hidden_states)
    dispose_tensor(hidden_states_scale)

    # Activation: silu_and_mul halves the feature dim (N -> N // 2).
    down_input = torch.empty(
        (
            all_tokens,
            N // 2,
        ),
        device=gateup_output.device,
        dtype=torch.bfloat16,
    )
    silu_and_mul(gateup_output.view(-1, N), down_input)
    del gateup_output

    # Re-quantize the activation to fp8 with per-token-group scales.
    down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
        down_input,
        scale_block_size,
        column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
    del down_input
    # GroupGemm-1: activation @ w2 -> bf16 output.
    down_output = torch.empty(
        (all_tokens, K),
        device=hidden_states_device,
        dtype=torch.bfloat16,
    )
    if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
        down_input_scale = tma_align_input_scale(down_input_scale)
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
        (down_input_fp8, down_input_scale),
        w2_weight_fp8,
        down_output,
        m_indices,
    )
    return down_output
def _run_masked_gemm(
self,
runner_input: DeepGemmRunnerInput,
......@@ -149,6 +254,7 @@ class DeepGemmRunnerCore(MoeRunnerCore):
expected_m,
)
dispose_tensor(hidden_states)
dispose_tensor(hidden_states_scale)
# Act
down_input = torch.empty(
......@@ -198,18 +304,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
masked_m,
expected_m,
)
del down_input
return down_output
def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
pass
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
......@@ -222,6 +319,7 @@ def pre_permute_standard_to_deep_gemm(
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
hidden_states, topk_output = dispatch_output
......@@ -257,9 +355,9 @@ def pre_permute_standard_to_deep_gemm(
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
use_masked_gemm=True,
masked_m=masked_m,
expected_m=expected_m,
use_masked_gemm=True,
)
......@@ -302,3 +400,170 @@ def post_permute_deep_gemm_to_standard(
return StandardCombineInput(
hidden_states=output,
)
@register_pre_permute("deepep_ll", "deep_gemm")
def pre_permute_deepep_ll_to_deep_gemm(
    dispatch_output: DeepEPLLDispatchOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepGemmRunnerInput:
    """Adapt a DeepEP low-latency dispatch output for the DeepGEMM runner.

    The low-latency layout needs no permutation; this only stashes the
    routing metadata in `running_state` for the post-permute step and
    wraps the tensors for the masked-GEMM path.
    """
    (
        hidden_states,
        hidden_states_scale,
        topk_ids,
        topk_weights,
        masked_m,
        expected_m,
    ) = dispatch_output

    running_state.update(
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        hidden_states_shape=hidden_states.shape,
        hidden_states_dtype=hidden_states.dtype,
        hidden_states_device=hidden_states.device,
    )

    return DeepGemmRunnerInput(
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        use_masked_gemm=True,
        masked_m=masked_m,
        expected_m=expected_m,
    )
@register_post_permute("deep_gemm", "deepep_ll")
def post_permute_deep_gemm_to_deepep_ll(
    runner_output: DeepGemmRunnerOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepEPLLCombineInput:
    """Wrap the DeepGEMM runner output as a DeepEP low-latency combine input.

    No un-permutation is needed for the low-latency layout; routing
    metadata is recovered from `running_state`.
    """
    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput

    expert_output = runner_output.hidden_states
    routed_ids = running_state["topk_ids"]
    routed_weights = running_state["topk_weights"]

    return DeepEPLLCombineInput(
        hidden_states=expert_output,
        topk_ids=routed_ids,
        topk_weights=routed_weights,
    )
@register_pre_permute("deepep_normal", "deep_gemm")
def pre_permute_deepep_normal_to_deep_gemm(
    dispatch_output: DeepEPNormalDispatchOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepGemmRunnerInput:
    """Scatter a DeepEP normal dispatch output into the contiguous-GEMM layout.

    Permutes the dispatched tokens into per-expert contiguous rows via
    ep_scatter and records everything the post-permute gather needs
    (topk ids/weights, output_index, original shape/device/dtype) in
    `running_state`.
    """
    from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter

    (
        hidden_states,
        hidden_states_scale,
        topk_ids,
        topk_weights,
        num_recv_tokens_per_expert,
    ) = dispatch_output

    assert runner_config.activation == "silu"

    # Total rows in the scattered layout (sum over experts).
    all_tokens = sum(num_recv_tokens_per_expert)
    running_state["all_tokens"] = all_tokens

    K = hidden_states.shape[1]

    # Remember shape/device/dtype before the input tensor is disposed.
    hidden_states_shape = hidden_states.shape
    hidden_states_device = hidden_states.device
    hidden_states_dtype = hidden_states.dtype
    running_state["hidden_states_shape"] = hidden_states_shape
    running_state["hidden_states_device"] = hidden_states_device
    running_state["hidden_states_dtype"] = hidden_states_dtype
    running_state["topk_ids"] = topk_ids
    running_state["topk_weights"] = topk_weights

    # Scatter destination for the token values.
    input_tensor = torch.empty(
        (all_tokens, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
        # UE8M0 scales are packed int32 (4 e8m0 bytes per int), allocated
        # K-major then viewed token-major via transpose.
        # TODO check whether need `zeros`
        input_tensor_scale = torch.zeros(
            (ceil_div(K // 128, 4), all_tokens),
            device=hidden_states.device,
            dtype=torch.int,
        ).transpose(0, 1)
    else:
        input_tensor_scale = torch.empty(
            (all_tokens, K // 128),
            device=hidden_states.device,
            dtype=torch.float32,
        )
    # Expert index of every scattered row (consumed by the grouped GEMM).
    m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
    # Per (token, expert-slot) row index into the scattered layout; used
    # later by ep_gather to un-permute the results.
    output_index = torch.empty_like(topk_ids)

    if get_offloader().forbid_copy_engine_usage:
        # Offloader reserves the copy engine; use the compute-stream copy.
        num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
            num_recv_tokens_per_expert
        )
    else:
        num_recv_tokens_per_expert_gpu = torch.tensor(
            num_recv_tokens_per_expert,
            dtype=torch.int32,
            pin_memory=True,
            device="cpu",
        ).cuda(non_blocking=True)
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

    # Permute tokens into per-expert contiguous rows.
    ep_scatter(
        hidden_states,
        hidden_states_scale,
        topk_ids,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        input_tensor,
        input_tensor_scale,
        m_indices,
        output_index,
        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
    # Free the dispatch buffers early to reduce peak memory.
    dispose_tensor(hidden_states)
    dispose_tensor(hidden_states_scale)

    running_state["output_index"] = output_index

    return DeepGemmRunnerInput(
        hidden_states=input_tensor,
        hidden_states_scale=input_tensor_scale,
        use_masked_gemm=False,
        m_indices=m_indices,
    )
@register_post_permute("deep_gemm", "deepep_normal")
def post_permute_deep_gemm_to_deepep_normal(
    runner_output: DeepGemmRunnerOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepEPNormalCombineInput:
    """Gather scattered DeepGEMM expert output back into token order.

    Uses the output_index recorded by the deepep_normal pre-permute to
    un-permute the expert outputs, then wraps the result as a DeepEP
    normal combine input.
    """
    from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput

    scattered_output = runner_output.hidden_states
    topk_ids = running_state["topk_ids"]
    topk_weights = running_state["topk_weights"]
    output_index = running_state["output_index"]

    # Allocate the un-permuted destination in the original token layout.
    gather_out = torch.empty(
        running_state["hidden_states_shape"],
        device=running_state["hidden_states_device"],
        dtype=torch.bfloat16,
    )
    ep_gather(scattered_output, topk_ids, topk_weights, output_index, gather_out)

    return DeepEPNormalCombineInput(
        hidden_states=gather_out,
        topk_ids=running_state["topk_ids"],
        topk_weights=running_state["topk_weights"],
    )
......@@ -12,9 +12,9 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLCombineInput,
DeepEPLLOutput,
DeepEPLLDispatchOutput,
DeepEPNormalCombineInput,
DeepEPNormalOutput,
DeepEPNormalDispatchOutput,
)
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
MooncakeCombineInput,
......@@ -44,8 +44,8 @@ __all__ = [
"StandardCombineInput",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
"DeepEPNormalDispatchOutput",
"DeepEPLLDispatchOutput",
"DeepEPLLCombineInput",
"DeepEPNormalCombineInput",
]
......@@ -9,9 +9,9 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLCombineInput,
DeepEPLLOutput,
DeepEPLLDispatchOutput,
DeepEPNormalCombineInput,
DeepEPNormalOutput,
DeepEPNormalDispatchOutput,
StandardCombineInput,
StandardDispatchOutput,
)
......@@ -37,19 +37,19 @@ class DispatchOutputChecker:
@staticmethod
def format_is_deepep_normal(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPNormalOutput]:
) -> TypeGuard[DeepEPNormalDispatchOutput]:
return dispatch_output.format.is_deepep_normal()
@staticmethod
def format_is_deepep_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPLLOutput]:
) -> TypeGuard[DeepEPLLDispatchOutput]:
return dispatch_output.format.is_deepep_ll()
@staticmethod
def format_is_deepep(
dispatch_output: DispatchOutput,
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
return dispatch_output.format.is_deepep()
......
......@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
logger = logging.getLogger(__name__)
class DeepEPNormalOutput(NamedTuple):
class DeepEPNormalDispatchOutput(NamedTuple):
"""DeepEP normal dispatch output."""
hidden_states: torch.Tensor
......@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
return DispatchOutputFormat.DEEPEP_NORMAL
class DeepEPLLOutput(NamedTuple):
class DeepEPLLDispatchOutput(NamedTuple):
"""DeepEP low latency dispatch output."""
hidden_states: torch.Tensor
......@@ -87,14 +87,17 @@ class DeepEPLLOutput(NamedTuple):
return DispatchOutputFormat.DEEPEP_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)
class DeepEPNormalCombineInput(NamedTuple):
"""DeepEP normal combine input."""
pass
hidden_states: torch.Tensor
topk_ids: torch.Tensor
topk_weights: torch.Tensor
overlap_args: Optional[CombineOverlapArgs] = None
@property
def format(self) -> CombineInputFormat:
......@@ -104,7 +107,10 @@ class DeepEPNormalCombineInput(NamedTuple):
class DeepEPLLCombineInput(NamedTuple):
"""DeepEP low latency combine input."""
pass
hidden_states: torch.Tensor
topk_ids: torch.Tensor
topk_weights: torch.Tensor
overlap_args: Optional[CombineOverlapArgs] = None
@property
def format(self) -> CombineInputFormat:
......@@ -383,7 +389,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
else:
hidden_states_scale = None
return DeepEPNormalOutput(
return DeepEPNormalDispatchOutput(
hidden_states,
hidden_states_scale,
topk_ids,
......@@ -562,7 +568,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
else:
hidden_states_scale = None
deepep_output = DeepEPLLOutput(
deepep_output = DeepEPLLDispatchOutput(
hidden_states,
hidden_states_scale,
topk_ids,
......@@ -756,18 +762,16 @@ class DeepEPDispatcher(BaseDispatcher):
del self._dispatch_intermediate_state
return self._get_impl().dispatch_b(*inner_state)
def combine(self, *args, **kwargs) -> Tuple:
self.combine_a(*args, **kwargs)
def combine(self, combine_input: CombineInput) -> Tuple:
    """Run the full combine: enqueue (combine_a) then collect (combine_b)."""
    self.combine_a(combine_input)
    return self.combine_b()
def combine_a(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional["CombineOverlapArgs"] = None,
combine_input: CombineInput,
):
hidden_states, topk_ids, topk_weights, overlap_args = combine_input
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl().combine_a(
hidden_states=hidden_states,
......
......@@ -984,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
topk_weights, topk_ids, _ = dispatch_output.topk_output
x, topk_weights = apply_topk_weights_cpu(
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
......@@ -1017,7 +1016,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
ret = self.maybe_apply_hip_fused_experts(
layer,
x,
topk_output,
dispatch_output.topk_output,
moe_runner_config.activation,
moe_runner_config.no_combine,
)
......@@ -1027,7 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self._should_use_cutlass_fused_experts():
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
topk_weights, topk_ids, _ = topk_output
topk_weights, topk_ids, _ = dispatch_output.topk_output
output = cutlass_fused_experts_fp8(
x,
layer.w13_weight.transpose(1, 2),
......
......@@ -23,8 +23,8 @@ if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
DeepEPLLOutput,
DeepEPNormalOutput,
DeepEPLLDispatchOutput,
DeepEPNormalDispatchOutput,
StandardDispatchOutput,
)
......@@ -332,7 +332,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply_deepep_ll(
self,
layer: DeepEPMoE,
dispatch_output: DeepEPLLOutput,
dispatch_output: DeepEPLLDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
......@@ -367,7 +367,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply_deepep_normal(
self,
layer: DeepEPMoE,
dispatch_output: DeepEPNormalOutput,
dispatch_output: DeepEPNormalDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.cutlass_w4a8_moe import (
cutlass_w4a8_moe_deepep_normal,
......
......@@ -1005,16 +1005,14 @@ class DeepseekV2MoE(nn.Module):
)
def op_experts(self, state):
    """Run the expert computation and stash its combine input on `state`.

    Post-refactor, run_moe_core returns a CombineInput consumed by
    op_combine_a; the old hidden_states_experts_output field is gone.
    (The diff residue left both the old and new assignment heads; this
    keeps the new one.)
    """
    state.combine_input = self.experts.run_moe_core(
        dispatch_output=state.dispatch_output,
    )
def op_combine_a(self, state):
    """Start the dispatcher combine phase when expert parallelism is on.

    Post-refactor, combine_a takes the CombineInput produced by
    op_experts instead of separate hidden_states/topk tensors. (The diff
    residue left both old and new keyword arguments; this keeps the new
    form.) The popped state entries release their tensors early.
    """
    if self.ep_size > 1:
        self.experts.dispatcher.combine_a(
            combine_input=state.pop("combine_input"),
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
        # NOTE(review): pop placement assumed inside the ep_size guard,
        # matching the surrounding structure — confirm against upstream.
        state.pop("dispatch_output")
......
......@@ -241,16 +241,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
)
def op_experts(self, state):
    """Run the expert computation and stash its combine input on `state`.

    Post-refactor, run_moe_core returns a CombineInput consumed by
    op_combine_a. (The diff residue left both the old and new assignment
    heads; this keeps the new one.)
    """
    state.combine_input = self.experts.run_moe_core(
        dispatch_output=state.dispatch_output,
    )
def op_combine_a(self, state):
    """Start the dispatcher combine phase when expert parallelism is on.

    Post-refactor, combine_a takes the CombineInput produced by
    op_experts instead of separate hidden_states/topk tensors. (The diff
    residue left both old and new keyword arguments; this keeps the new
    form.) The popped state entries release their tensors early.
    """
    if self.ep_size > 1:
        self.experts.dispatcher.combine_a(
            combine_input=state.pop("combine_input"),
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
        # NOTE(review): pop placement assumed inside the ep_size guard,
        # matching the surrounding structure — confirm against upstream.
        state.pop("dispatch_output")
......
......@@ -85,7 +85,7 @@ def execute_sbo(
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
)
hidden_states = experts.run_moe_core(
combine_input = experts.run_moe_core(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
......@@ -98,12 +98,7 @@ def execute_sbo(
):
forward_shared_experts()
hidden_states = experts.dispatcher.combine(
hidden_states=hidden_states,
topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights,
overlap_args=combine_overlap_args,
)
hidden_states = experts.dispatcher.combine(combine_input=combine_input)
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment