Unverified Commit e67276ec, authored by tql.99, committed by GitHub

feat: support cutlass_moe_fp8 kernel for fusedmoe in sm90 (#8678)

parent 0242bb9c
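
Summary of the change: the FP8 CUTLASS fused-MoE path, previously gated to Blackwell (SM100), now also runs on Hopper (SM90). Three files are touched: `sglang/srt/layers/utils.py` adds an `is_sm90_supported` check (and wraps both capability checks in `lru_cache`); the `Fp8MoEMethod` gates accept SM90 alongside SM100; and `cutlass_fused_experts_fp8` gains a Hopper branch that shuffles activation rows into expert order first and then quantizes them with `per_token_group_quant_fp8_hopper_moe_mn_major`, whereas the SM100 branch quantizes first and shuffles afterwards.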
--- a/sglang/srt/layers/moe/cutlass_moe.py
+++ b/sglang/srt/layers/moe/cutlass_moe.py
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
     if _is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
     n = w2_q.size(1)
     topk = topk_ids.size(1)
 
-    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-
-    device = a_q.device
+    device = a.device
 
     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8(
         k,
     )
 
-    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    if is_sm100_supported():
+        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    else:
+        rep_a = shuffle_rows(a, a_map, (m * topk, k))
+        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
+            rep_a, expert_offsets, problem_sizes1, 128
+        )
+
     w1_scale = w1_scale.contiguous()
 
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
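
Two details in the hunks above are easy to miss: `device` is now read from `a` instead of `a_q`, since on the SM90 branch quantization is deferred until after the row shuffle and `a_q` no longer exists at that point; and, judging by the kernel name (an inference, not something the diff states), the Hopper grouped GEMM wants its per-group scales laid out MN-major, which is why quantization is fused with the per-expert layout on that branch. For orientation, here is a plain-PyTorch sketch of what per-token-group FP8 quantization with group size 128 computes; `per_token_group_quant_fp8_ref` is a reference stand-in for `sglang_per_token_group_quant_fp8`, not the fused kernel:

```python
import torch

def per_token_group_quant_fp8_ref(a: torch.Tensor, group_size: int = 128):
    m, k = a.shape
    assert k % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per (token, 128-wide group): max magnitude over the group.
    groups = a.float().reshape(m, k // group_size, group_size)
    scales = groups.abs().amax(dim=-1).clamp(min=1e-12) / fp8_max
    q = (groups / scales.unsqueeze(-1)).reshape(m, k).to(torch.float8_e4m3fn)
    # Shapes match the diff: q is (m, k) fp8, scales is (m, k // 128) fp32.
    return q, scales
```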
@@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
-    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if is_sm100_supported():
+        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    else:
+        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
+            intermediate, expert_offsets, problem_sizes2, 128
+        )
+
     w2_scale = w2_scale.contiguous()
     fp8_blockwise_scaled_grouped_mm(
         c2,
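
The first grouped GEMM writes `c1` with width `2 * n`; `silu_and_mul` then applies SwiGLU gating and halves the width into `intermediate`, which the branch above re-quantizes to FP8 for the second grouped GEMM. A plain-PyTorch sketch of that operation, assuming the usual gate-first layout (the real op is a fused kernel writing into a preallocated buffer):

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split (..., 2n) into gate and up halves; SiLU the gate and multiply.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
```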
--- a/sglang/srt/layers/quantization/fp8.py
+++ b/sglang/srt/layers/quantization/fp8.py
@@ -63,7 +63,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -619,7 +619,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if (
             get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
-            and is_sm100_supported()
+            and (is_sm100_supported() or is_sm90_supported())
         ):
             self.ab_strides1 = torch.full(
                 (num_experts,),
@@ -1034,7 +1034,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and self.block_quant
-            and is_sm100_supported()
+            and (is_sm100_supported() or is_sm90_supported())
         ):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
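
Taken together, the CUTLASS path now activates when the `SGLANG_CUTLASS_MOE` environment variable is truthy and the device is SM90 or SM100 with a new-enough CUDA toolkit. A condensed, hypothetical view of the gate; the env parsing and argument names are stand-ins for sglang's actual `get_bool_env_var` helper and class attributes:

```python
import os

def use_cutlass_moe(
    cutlass_fp8_supported: bool, block_quant: bool, sm_major: int, cuda: tuple
) -> bool:
    # Stand-in for get_bool_env_var("SGLANG_CUTLASS_MOE").
    if os.environ.get("SGLANG_CUTLASS_MOE", "").lower() not in ("1", "true"):
        return False
    sm100_ok = sm_major == 10 and cuda >= (12, 8)  # mirrors is_sm100_supported
    sm90_ok = sm_major == 9 and cuda >= (12, 3)    # mirrors is_sm90_supported
    return cutlass_fp8_supported and block_quant and (sm100_ok or sm90_ok)
```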
--- a/sglang/srt/layers/utils.py
+++ b/sglang/srt/layers/utils.py
 import logging
 import re
+from functools import lru_cache
 
 import torch
@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input
 
 
+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
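
One caveat in these helpers: `torch.version.cuda` is a string, and `>=` on strings is lexicographic, so for example `"12.10" < "12.3"` even though 12.10 is the newer toolkit. A numeric comparison avoids that edge case; a minimal sketch:

```python
import torch

def cuda_version_at_least(major: int, minor: int) -> bool:
    if torch.version.cuda is None:  # CPU-only or ROCm builds report None
        return False
    got = tuple(int(p) for p in torch.version.cuda.split(".")[:2])
    return got >= (major, minor)
```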