Unverified Commit e67276ec authored by tql.99, committed by GitHub

feat: support cutlass_moe_fp8 kernel for fusedmoe in sm90 (#8678)

parent 0242bb9c
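
In outline: on SM100 (Blackwell) the activations are quantized first and the quantized rows and scales are shuffled per (token, expert) pair, while on SM90 (Hopper) the rows are shuffled first and then quantized by per_token_group_quant_fp8_hopper_moe_mn_major, presumably emitting scales in the MN-major layout its name suggests the SM90 grouped GEMM expects. A condensed sketch of the new dispatch, using the names from the diff below (not a standalone snippet):

```python
if is_sm100_supported():
    # Blackwell path: quantize activations, then gather rows per expert.
    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
else:
    # Hopper path: gather rows first, then quantize into MN-major scales.
    rep_a = shuffle_rows(a, a_map, (m * topk, k))
    rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
        rep_a, expert_offsets, problem_sizes1, 128
    )
```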
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
     n = w2_q.size(1)
     topk = topk_ids.size(1)

-    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-    device = a_q.device
+    device = a.device

     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8(
         k,
     )

-    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    if is_sm100_supported():
+        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    else:
+        rep_a = shuffle_rows(a, a_map, (m * topk, k))
+        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
+            rep_a, expert_offsets, problem_sizes1, 128
+        )
+        w1_scale = w1_scale.contiguous()

     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)

-    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if is_sm100_supported():
+        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    else:
+        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
+            intermediate, expert_offsets, problem_sizes2, 128
+        )
+        w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
         c2,
...
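
For readers unfamiliar with the blockwise-FP8 convention above: sglang_per_token_group_quant_fp8(a, 128) produces one scale per 128-element group of each row, which is why the scale tensor is shuffled with shape (m * topk, k / 128). Below is a minimal unfused PyTorch reference of that contract; it is illustrative only, and the real kernel is fused, so its exact dtype and layout choices may differ:

```python
import torch


def per_token_group_quant_fp8_ref(a: torch.Tensor, group_size: int = 128):
    """Reference (unfused) per-token-group FP8 quantization sketch.

    Returns (a_q, a_scale), where a_scale has shape (m, k // group_size):
    one scale per 128-wide group of each row.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    m, k = a.shape
    assert k % group_size == 0
    g = a.reshape(m, k // group_size, group_size).to(torch.float32)
    # One scale per (token, group): max-abs mapped onto the fp8 range.
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scale = amax / finfo.max
    q = (g / scale).clamp(finfo.min, finfo.max)
    a_q = q.to(torch.float8_e4m3fn).reshape(m, k)
    return a_q, scale.squeeze(-1)
```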
@@ -63,7 +63,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -619,7 +619,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if (
             get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
-            and is_sm100_supported()
+            and (is_sm100_supported() or is_sm90_supported())
         ):
             self.ab_strides1 = torch.full(
                 (num_experts,),
@@ -1034,7 +1034,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and self.block_quant
-            and is_sm100_supported()
+            and (is_sm100_supported() or is_sm90_supported())
         ):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
...
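
Both call sites in Fp8MoEMethod gate the CUTLASS MoE path the same way; schematically (the dispatch site additionally requires self.block_quant, as shown in the second hunk):

```python
use_cutlass_moe = (
    get_bool_env_var("SGLANG_CUTLASS_MOE")             # explicit opt-in
    and self.cutlass_fp8_supported                     # CUTLASS FP8 kernels available
    and (is_sm100_supported() or is_sm90_supported())  # Blackwell or, now, Hopper
)
```

Assuming get_bool_env_var follows the usual truthy-string convention, setting something like SGLANG_CUTLASS_MOE=1 in the environment enables this path on H100/H200-class GPUs after this change.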
 import logging
 import re
+from functools import lru_cache

 import torch

@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input


+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
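
One caveat worth noting: torch.version.cuda is a string, so >= "12.3" compares lexicographically (a hypothetical "12.10" would sort below "12.3"). A more defensive variant, not part of this commit, would compare parsed components:

```python
import torch


def _cuda_at_least(major: int, minor: int) -> bool:
    # Hypothetical helper, not in the diff: compare CUDA versions
    # numerically instead of lexicographically.
    if torch.version.cuda is None:  # CPU-only build of PyTorch
        return False
    parts = torch.version.cuda.split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)
```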