Unverified Commit 29589512 authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
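
Every file in this commit applies the same mechanical change: the routing knobs that used to be passed to `select_experts` one keyword at a time (`top_k`, `renormalize`, and the grouped-top-k options) now travel on a single `TopKConfig` object. A minimal sketch of the before/after call pattern, assuming a CUDA build of sglang; the tensor shapes are illustrative and not taken from any of the tests:

```python
import torch

from sglang.srt.layers.moe.topk import TopKConfig, select_experts

M, E, topk = 4, 8, 2  # tokens, experts, experts per token (illustrative)
a = torch.randn((M, 16), dtype=torch.float16, device="cuda")     # hidden states
score = torch.randn((M, E), dtype=torch.float16, device="cuda")  # router logits

# Before this commit:
#   topk_weights, topk_ids, _ = select_experts(
#       hidden_states=a, router_logits=score, top_k=topk, renormalize=False
#   )

# After: routing parameters are bundled into a TopKConfig.
topk_output = select_experts(
    hidden_states=a,
    router_logits=score,
    topk_config=TopKConfig(top_k=topk, renormalize=False),
)

# Callers that need the raw tensors still unpack the result as a 3-tuple,
# exactly as the updated tests below do.
topk_weights, topk_ids, _ = topk_output
```
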
@@ -6,7 +6,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)

         with torch.inference_mode():
+            ref_out = torch_w8a8_block_fp8_moe(
+                a, w1, w2, w1_s, w2_s, score, topk, block_size
+            )
             topk_output = select_experts(
                 hidden_states=a,
                 router_logits=score,
-                top_k=topk,
-                renormalize=False,
+                topk_config=TopKConfig(top_k=topk, renormalize=False),
             )
             out = fused_moe(
                 a,
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
                 w2_scale=w2_s,
                 block_shape=block_size,
             )
-            ref_out = torch_w8a8_block_fp8_moe(
-                a, w1, w2, w1_s, w2_s, score, topk, block_size
-            )
             self.assertTrue(
                 torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
...
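
The hunk above also shows that `fused_moe` consumes the `select_experts` result as one `topk_output` argument, so the Triton-path tests never unpack it; only the EP and CUTLASS tests further down need the raw `topk_weights`/`topk_ids`. A self-contained sketch of the direct-pass pattern, using the unquantized `fused_moe(a, w1, w2, topk_output)` call that appears verbatim later in this diff (the weight shapes follow the usual fused gate/up convention and are an assumption here):

```python
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts

M, K, N, E, topk = 4, 16, 32, 8, 2  # illustrative sizes
a = torch.randn((M, K), dtype=torch.float16, device="cuda")
w1 = torch.randn((E, 2 * N, K), dtype=torch.float16, device="cuda")  # gate/up proj
w2 = torch.randn((E, K, N), dtype=torch.float16, device="cuda")      # down proj
score = torch.randn((M, E), dtype=torch.float16, device="cuda")

topk_output = select_experts(
    hidden_states=a,
    router_logits=score,
    topk_config=TopKConfig(top_k=topk, renormalize=False),
)
# The routing result is passed through whole; no unpacking needed.
out = fused_moe(a, w1, w2, topk_output)
```
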
@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.test.test_utils import CustomTestCase
@@ -22,35 +22,26 @@ def ep_moe(
     w1: torch.Tensor,
     w2: torch.Tensor,
     router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
+    topk_config: TopKConfig,
     # ep config
     num_experts: int = 256,
     fp8_dtype: torch.types = torch.float8_e4m3fn,
     num_experts_per_partition: int = 128,
     start_expert_id: int = 0,
     end_expert_id: int = 127,
-    use_grouped_topk: bool = False,
-    num_expert_group: Optional[int] = None,
-    topk_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
     w1_scale_inv: Optional[torch.Tensor] = None,
     w2_scale_inv: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
 ):
     use_blockwise_fp8 = block_shape is not None
-    topk_weights, topk_ids, _ = select_experts(
+    top_k = topk_config.top_k
+    topk_output = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        # correction_bias=correction_bias, #skip this in test
-        custom_routing_function=custom_routing_function,
+        topk_config=topk_config,
     )
+    topk_weights, topk_ids, _ = topk_output
     reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
         start_id = cur_rank * num_experts_per_partition
         end_id = start_id + num_experts_per_partition - 1

+        topk_config = TopKConfig(
+            top_k=topk,
+            renormalize=False,
+        )
+
         with torch.inference_mode():
             out = ep_moe(
                 hidden_states=a,
                 w1=w1,
                 w2=w2,
                 router_logits=score,
-                top_k=topk,
-                renormalize=False,
+                topk_config=topk_config,
                 use_fp8_w8a8=True,
                 w1_scale_inv=w1_s,
                 w2_scale_inv=w2_s,
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
                 w1=w1_ref,
                 w2=w2_ref,
                 router_logits=score,
-                top_k=topk,
-                renormalize=False,
+                topk_config=topk_config,
                 use_fp8_w8a8=False,
                 w1_scale_inv=None,
                 w2_scale_inv=None,
...
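
The `ep_moe` helper above loses five routing parameters (`renormalize`, `use_grouped_topk`, `num_expert_group`, `topk_group`, `custom_routing_function`). Only `top_k` and `renormalize` are confirmed `TopKConfig` fields by this diff; the grouped-top-k options presumably moved onto the config as well, but that is an assumption. A hedged sketch of the updated test call under that assumption:

```python
# Confirmed fields: top_k, renormalize. The commented-out names below are
# ASSUMED to exist on TopKConfig (they were dropped from ep_moe's signature
# in this commit); check sglang.srt.layers.moe.topk before relying on them.
topk_config = TopKConfig(
    top_k=2,
    renormalize=False,
    # use_grouped_topk=True,   # assumed field name
    # num_expert_group=8,      # assumed field name
    # topk_group=4,            # assumed field name
)

# Mirrors the updated TestW8A8BlockFP8EPMoE call, reusing the test's tensors.
out = ep_moe(
    hidden_states=a,
    w1=w1,
    w2=w2,
    router_logits=score,
    topk_config=topk_config,
    use_fp8_w8a8=True,
    w1_scale_inv=w1_s,
    w2_scale_inv=w2_s,
)
```
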
@@ -6,7 +6,7 @@ import pytest
 import torch

 from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts


 def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
@@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     s_strides2 = c_strides2

     score = torch.randn((M, E), dtype=dtype, device=device)
-    topk_weights, topk_ids, _ = select_experts(
+    topk_output = select_experts(
         hidden_states=a,
         router_logits=score,
-        top_k=topk,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
     )
+    topk_weights, topk_ids, _ = topk_output
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
     expert_map[local_e:] = E
...
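
One detail of the w4a8 hunk worth spelling out is the `expert_map` idiom at its end: expert ids outside the local partition are remapped to the sentinel value `E`, one past the last valid id, so the kernel can recognize and skip non-local experts. A self-contained illustration:

```python
import torch

E, local_e = 8, 4  # total experts, experts owned by this rank (illustrative)
expert_map = torch.arange(E, dtype=torch.int32)
expert_map[local_e:] = E  # non-local experts map to the out-of-range id E
print(expert_map)  # tensor([0, 1, 2, 3, 8, 8, 8, 8], dtype=torch.int32)
```
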
@@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts

 if torch.cuda.get_device_capability() < (10, 0):
     pytest.skip(
@@ -163,11 +163,12 @@ def check_moe(
     score = torch.randn((m, e), device="cuda", dtype=dtype)

-    topk_weights, topk_ids, _ = select_experts(
+    topk_output = select_experts(
         hidden_states=a,
         router_logits=score,
-        top_k=topk,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
     )
+    topk_weights, topk_ids, _ = topk_output

     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
...
@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.test.test_utils import CustomTestCase
@@ -175,10 +175,13 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )

         with torch.inference_mode():
+            ref_out = torch_w8a8_block_int8_moe(
+                a, w1, w2, w1_s, w2_s, score, topk, block_size
+            )
             out = fused_moe(
                 a,
                 w1,
@@ -189,9 +192,6 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
                 w2_scale=w2_s,
                 block_shape=block_size,
             )
-            ref_out = torch_w8a8_block_int8_moe(
-                a, w1, w2, w1_s, w2_s, score, topk, block_size
-            )
             self.assertTrue(
                 torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
...
@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.test.test_utils import CustomTestCase
@@ -118,7 +118,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )
         out = fused_moe(
             a,
...
@@ -6,7 +6,7 @@ from tqdm import tqdm
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.utils import is_hip
@@ -136,19 +136,7 @@ class TestFusedMOE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )
-
-        sglang_output = fused_moe(
-            a,
-            w1,
-            w2,
-            topk_output,
-            use_fp8_w8a8=True,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            a1_scale=a1_scale,
-            a2_scale=a2_scale,
-        )

         torch_output = self.torch_naive_moe(
@@ -162,6 +150,18 @@ class TestFusedMOE(CustomTestCase):
             a1_scale,
             a2_scale,
         )
+
+        sglang_output = fused_moe(
+            a,
+            w1,
+            w2,
+            topk_output,
+            use_fp8_w8a8=True,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
         torch.testing.assert_close(
             sglang_output, torch_output, rtol=rtol, atol=atol
         )
@@ -174,7 +174,7 @@ class TestFusedMOE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )
         triton_output = fused_moe(a, w1, w2, topk_output)
...
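
Both hunks in the file above also reorder the test so the `torch_naive_moe` reference is computed before the fused kernel runs. The reference implementation itself is not part of this diff; the following is a hypothetical reconstruction of what such a naive per-token loop typically looks like for the unquantized case (it is not the test's actual `torch_naive_moe`), using the same silu-and-mul gating that the `SiluAndMul` import at the top of these files suggests:

```python
import torch
import torch.nn.functional as F


def naive_moe_reference(a, w1, w2, topk_weights, topk_ids):
    """Hypothetical naive reference: route each token through its top-k
    experts sequentially and sum the weighted expert outputs.

    a: [M, K] tokens, w1: [E, 2N, K] fused gate/up proj, w2: [E, K, N] down proj.
    """
    # Accumulate in float32 so mixed weight/activation dtypes stay valid.
    a32, w1_32, w2_32 = a.float(), w1.float(), w2.float()
    out = torch.zeros_like(a32)
    N = w2.shape[-1]
    for i in range(a.shape[0]):
        for weight, expert in zip(topk_weights[i], topk_ids[i]):
            gate_up = a32[i] @ w1_32[expert].t()     # [2N]
            h = F.silu(gate_up[:N]) * gate_up[N:]    # silu-and-mul gating, [N]
            out[i] += weight.float() * (h @ w2_32[expert].t())  # [K]
    return out.to(a.dtype)
```
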
@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.test.test_utils import CustomTestCase
@@ -130,7 +130,7 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
         topk_output = select_experts(
             hidden_states=a,
             router_logits=score,
-            top_k=topk,
+            topk_config=TopKConfig(top_k=topk, renormalize=False),
         )
         out = fused_moe(
             a,
...
@@ -5,7 +5,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts

 NUM_EXPERTS = [8, 64]
 TOP_KS = [2, 6]
@@ -223,7 +223,7 @@ def test_fused_moe_wn16(
     topk_output = select_experts(
         hidden_states=a,
         router_logits=score,
-        top_k=topk,
+        topk_config=TopKConfig(top_k=topk),
     )
     triton_output = fused_moe(
...
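
Unlike every other call site in this commit, the wn16 test builds `TopKConfig(top_k=topk)` without `renormalize`, so the field evidently has a default. A hypothetical minimal reconstruction of the config as these tests exercise it; the real class in `sglang.srt.layers.moe.topk` almost certainly carries more fields, and the default value shown is an assumption that merely matches every explicit use above:

```python
from dataclasses import dataclass


@dataclass
class TopKConfig:
    """HYPOTHETICAL minimal reconstruction, inferred only from these tests."""

    top_k: int
    renormalize: bool = False  # assumed default; the wn16 test omits it
```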