Unverified commit 15ad6c90, authored by Cheng Wan and committed by GitHub

[1/N] MoE Refactor: refactor `select_experts` (#7966)

parent cfab0ff6
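Across the hunks below, the routing configuration that previously lived on the experts layer (`renormalize`, `use_grouped_topk`, `num_expert_group`, `topk_group`, `correction_bias`, `custom_routing_function`, ...) and the standalone `select_experts` calls are replaced by a `TopK` module: it is constructed once in `__init__` with the routing options and invoked in `forward` to produce a `topk_output`, which the experts layer consumes instead of raw `router_logits`. Below is a minimal sketch of that call pattern; the wrapper class, its constructor arguments, and the plain linear gate are illustrative assumptions, and the experts layer is left abstract because FusedMoE/EPMoE construction depends on quantization and TP/EP state set up elsewhere in sglang.

```python
# Sketch only: mirrors the pattern introduced by this commit. The wrapper
# class and its constructor arguments are illustrative, not part of the diff;
# see the per-model hunks below for the real TopK/experts configurations.
import torch
from torch import nn

from sglang.srt.layers.moe.topk import TopK


class MoEBlockSketch(nn.Module):
    def __init__(self, num_experts: int, top_k: int, hidden_size: int, experts: nn.Module):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # Routing options now belong to TopK rather than the experts layer.
        self.topk = TopK(top_k=top_k, renormalize=True)
        # e.g. a FusedMoE / EPMoE instance, built as shown in the hunks below.
        self.experts = experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        # Old: self.experts(hidden_states=hidden_states, router_logits=router_logits)
        # New: TopK produces topk_output, which the experts layer consumes.
        topk_output = self.topk(hidden_states, router_logits)
        return self.experts(hidden_states, topk_output)
```

In the dispatch (DeepEP) paths, the same module call is unpacked as `topk_weights, topk_idx, _ = self.topk(...)`, so the output behaves as a three-element structure whose first two fields are the selected weights and expert indices.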
......@@ -58,7 +58,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
......@@ -303,6 +303,17 @@ class DeepseekV2MoE(nn.Module):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
......@@ -311,13 +322,7 @@ class DeepseekV2MoE(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
......@@ -451,8 +456,9 @@ class DeepseekV2MoE(nn.Module):
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, topk_output=topk_output
)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
......@@ -473,8 +479,9 @@ class DeepseekV2MoE(nn.Module):
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, topk_output=topk_output
)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
......@@ -490,8 +497,9 @@ class DeepseekV2MoE(nn.Module):
) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
fused_experts_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, topk_output=topk_output
)
assert use_intel_amx_backend(
......@@ -549,17 +557,9 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
......@@ -649,17 +649,9 @@ class DeepseekV2MoE(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
state.topk_weights_local, state.topk_idx_local = select_experts(
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=state.forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
......
......@@ -15,6 +15,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -60,6 +61,11 @@ class GraniteMoeMoE(nn.Module):
prefix=f"{prefix}.gate",
)
self.topk = TopK(
top_k=top_k,
renormalize=True,
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
......@@ -67,7 +73,6 @@ class GraniteMoeMoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts",
......@@ -78,7 +83,8 @@ class GraniteMoeMoE(nn.Module):
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
return final_hidden_states.view(orig_shape)
......
......@@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -108,6 +109,12 @@ class Grok1MoE(nn.Module):
fused_moe_router_shim, self.router_logit_softcapping
)
self.topk = TopK(
top_k=top_k,
renormalize=False,
custom_routing_function=custom_routing_function,
)
kwargs = {}
if global_server_args_dict["enable_ep_moe"]:
MoEImpl = EPMoE
......@@ -124,17 +131,16 @@ class Grok1MoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
custom_routing_function=custom_routing_function,
activation="gelu",
**kwargs,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# need to assert self.gate.quant_method is unquantized
return self.experts(hidden_states, self.gate.weight)
topk_output = self.topk(hidden_states, self.gate.weight)
return self.experts(hidden_states, topk_output)
class Grok1Attention(nn.Module):
......
......@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -152,13 +153,16 @@ class HunYuanSparseMoeBlock(nn.Module):
else config.moe_intermediate_size[layer_id]
)
self.topk = TopK(
top_k=top_k,
renormalize=True if top_k > 1 else False,
)
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True if top_k > 1 else False,
quant_config=quant_config,
)
......@@ -195,9 +199,8 @@ class HunYuanSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
......
......@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -103,14 +104,17 @@ class Llama4MoE(nn.Module):
prefix=add_prefix("router", prefix),
)
self.topk = TopK(
top_k=self.top_k,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
)
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
apply_router_weight_on_input=True,
prefix=add_prefix("experts", prefix),
......@@ -147,10 +151,8 @@ class Llama4MoE(nn.Module):
# router_scores: [num_tokens, num_experts]
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
topk_output = self.topk(hidden_states, router_logits)
routed_out = self.experts(hidden_states, topk_output)
return shared_out, routed_out
def _forward_core_shared_routed_overlap(self, hidden_states):
......@@ -163,10 +165,8 @@ class Llama4MoE(nn.Module):
with self.device_module.stream(alt_stream):
# router_scores: [num_tokens, num_experts]
router_logits, _ = self.router(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
topk_output = self.topk(hidden_states, router_logits)
routed_out = self.experts(hidden_states, topk_output)
self.device_module.current_stream().wait_stream(alt_stream)
return shared_out, routed_out
......
......@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -86,6 +87,12 @@ class MixtralMoE(nn.Module):
quant_config=None,
prefix=add_prefix("gate", prefix),
)
self.topk = TopK(
top_k=top_k,
renormalize=True,
)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=num_experts,
......@@ -93,7 +100,6 @@ class MixtralMoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix),
......@@ -105,7 +111,8 @@ class MixtralMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape)
......
......@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -76,13 +77,16 @@ class OlmoeMoE(nn.Module):
prefix=add_prefix("gate", prefix),
)
self.topk = TopK(
top_k=top_k,
renormalize=False,
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix),
......@@ -94,9 +98,8 @@ class OlmoeMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
return final_hidden_states.view(orig_shape)
......
......@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -200,15 +201,19 @@ class PhiMoE(nn.Module):
quant_config=None,
)
self.topk = TopK(
top_k=top_k,
renormalize=False,
custom_routing_function=phimoe_routing_function,
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
custom_routing_function=phimoe_routing_function,
prefix=add_prefix("experts", prefix),
)
......@@ -219,7 +224,8 @@ class PhiMoE(nn.Module):
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
return final_hidden_states.view(orig_shape)
......
......@@ -61,6 +61,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -134,13 +135,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}."
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=config.norm_topk_prob,
)
self.experts = get_moe_impl_class()(
layer_id=self.layer_id,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
# Additional args for FusedMoE
......@@ -189,9 +194,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
......
......@@ -56,8 +56,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -102,6 +101,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}."
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=config.norm_topk_prob,
use_grouped_topk=False,
)
self.experts = get_moe_impl_class()(
num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"],
......@@ -109,7 +114,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
layer_id=layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
**(
......@@ -143,7 +147,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok
self.renormalize = config.norm_topk_prob
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
......@@ -180,9 +183,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
......@@ -195,13 +197,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
if is_non_idle_and_non_empty(forward_mode, hidden_states):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=self.renormalize,
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
......@@ -267,12 +265,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
state.topk_weights_local, state.topk_idx_local = select_experts(
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=self.renormalize,
num_token_non_padded=state.forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
......
......@@ -6,6 +6,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
......@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
renormalize=False,
)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
......
......@@ -40,7 +40,7 @@ def ep_moe(
block_shape: Optional[List[int]] = None,
):
use_blockwise_fp8 = block_shape is not None
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
......
......@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
s_strides2 = c_strides2
score = torch.randn((M, E), dtype=dtype, device=device)
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E
......
......@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
......
......@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.test.test_utils import CustomTestCase
......@@ -171,14 +172,18 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
score = torch.randn((M, E), dtype=dtype)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
with torch.inference_mode():
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
......
......@@ -6,6 +6,7 @@ from tqdm import tqdm
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.utils import is_hip
......@@ -132,13 +133,17 @@ class TestFusedMOE(CustomTestCase):
input_scale=a2_scale,
)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
sglang_output = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
......@@ -166,7 +171,13 @@ class TestFusedMOE(CustomTestCase):
w2 = self.create_random_cuda_tensor((e, k, n), dtype)
score = self.create_random_cuda_tensor((m, e), dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
triton_output = fused_moe(a, w1, w2, topk_output)
torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
torch.testing.assert_close(
triton_output, torch_output, rtol=rtol, atol=atol
......
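The test hunks follow the same pattern: `fused_moe` no longer takes the raw scores, `topk`, and `renormalize` arguments, but a `topk_output` produced by `select_experts`. A small self-contained sketch of the updated call is below; shapes, dtype, and the CUDA device are illustrative assumptions (the real tests parameterize them), while the tensor layouts follow the TestFusedMOE hunk above.

```python
# Sketch of the updated test-side call pattern (mirrors the hunks above);
# shapes, dtype, and the CUDA device are illustrative assumptions.
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts

M, K, N, E, topk = 16, 128, 256, 8, 2
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * N, K, dtype=torch.float16, device="cuda")  # (e, 2*n, k) as in the tests above
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda")      # (e, k, n) as in the tests above
score = torch.randn(M, E, dtype=torch.float16, device="cuda")

# Routing is resolved first; renormalize now lives here, not on fused_moe.
topk_output = select_experts(
    hidden_states=a,
    router_logits=score,
    top_k=topk,
    renormalize=False,
)

# Replaces the old fused_moe(a, w1, w2, score, topk, renormalize=False) call.
out = fused_moe(a, w1, w2, topk_output)  # same shape as a
```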
......@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.test.test_utils import CustomTestCase
......@@ -114,13 +115,16 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
with torch.inference_mode():
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=False, # Not using fp8
use_int8_w8a16=False, # Not using int8-w8a16
use_int8_w8a8=True, # Using int8-w8a8
......
......@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.test.test_utils import CustomTestCase
......@@ -126,13 +127,16 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
with torch.inference_mode():
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=True, # using fp8
use_int8_w8a16=False,
use_int8_w8a8=False,
......
......@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
......@@ -219,13 +220,17 @@ def test_fused_moe_wn16(
if has_zp:
w_qzeros[expert_id] = qzeros
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
)
triton_output = fused_moe(
a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
topk_output,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=w1_scales,
......