Unverified commit 791b3bfa, authored by Baizhou Zhang and committed by GitHub

[Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (#6479)

parent 31589e17
@@ -57,6 +57,10 @@ SGLang supports various environment variables that can be used to configure its
 | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
 | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
+| `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
+| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
+| `SGLANG_CUTLASS_MOE` | Use Cutlass FP8 MoE kernel on Blackwell GPUs | `false` |
 
 ## Distributed Computing
...
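All three flags documented above are read through `get_bool_env_var`, so they can be exported in the shell or set from Python before the server starts. A minimal usage sketch (assuming `"1"` counts as true, consistent with the other boolean flags in this table):

```python
# Hedged usage sketch: opt in to the new blockwise fp8 kernel paths.
import os

os.environ["SGLANG_ENABLE_FLASHINFER_GEMM"] = "1"     # flashinfer blockwise fp8 GEMM (Blackwell only)
os.environ["SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"] = "1"  # Cutlass blockwise fp8 GEMM (Hopper/Blackwell)
os.environ["SGLANG_CUTLASS_MOE"] = "1"                # Cutlass FP8 MoE kernel (Blackwell only)
```

Note that `SGLANG_CUTLASS_MOE` and `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` replace the previously unprefixed `CUTLASS_MOE` and `SUPPORT_CUTLASS_BLOCK_FP8` names, as the renames in the hunks below show.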
@@ -571,7 +571,7 @@ class Fp8MoEMethod:
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
         if (
-            get_bool_env_var("CUTLASS_MOE")
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and is_sm100_supported()
         ):
@@ -973,7 +973,7 @@ class Fp8MoEMethod:
             return ret
 
         if (
-            get_bool_env_var("CUTLASS_MOE")
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
             and self.cutlass_fp8_supported
             and self.block_quant
             and is_sm100_supported()
...
@@ -28,6 +28,7 @@ from sglang.srt.utils import (
     get_cuda_version,
     get_device_capability,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
 )
@@ -35,6 +36,7 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 
 use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
 
 if _is_hip and use_aiter_moe:
@@ -111,7 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 def cutlass_block_fp8_supported() -> bool:
-    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
+    if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
@@ -123,6 +125,13 @@ def cutlass_block_fp8_supported() -> bool:
 CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
 
+ENABLE_FLASHINFER_GEMM = (
+    get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
+    and is_sm100_supported()
+    and is_flashinfer_available()
+)
+if ENABLE_FLASHINFER_GEMM:
+    from flashinfer.gemm import gemm_fp8_nt_groupwise
 
 
 def apply_w8a8_block_fp8_linear(
@@ -141,7 +150,16 @@ def apply_w8a8_block_fp8_linear(
     shape_supported_by_cutlass = (
         weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
     )
-    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
+    if ENABLE_FLASHINFER_GEMM:
+        q_input, x_scale = sglang_per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        x_scale_input = x_scale.T.contiguous()
+        weight_scale_input = weight_scale.T.contiguous()
+        output = gemm_fp8_nt_groupwise(
+            q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
+        )
+    elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=True
         )
...
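For readability, the sketch below restates the new dispatch order in `apply_w8a8_block_fp8_linear` outside the diff. It is not the verbatim SGLang implementation: the helper names (`sglang_per_token_group_quant_fp8`, `per_token_group_quant_fp8`, `gemm_fp8_nt_groupwise`) are taken from the hunks above, module globals such as `ENABLE_FLASHINFER_GEMM` are assumed to be in scope, and the pre-existing Cutlass and fallback branches are elided.

```python
def blockwise_fp8_linear_sketch(input, weight, weight_scale, block_size):
    # Flatten activations to 2D, as the real function does before dispatching.
    input_2d = input.view(-1, input.shape[-1])
    shape_supported_by_cutlass = (
        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
    )

    if ENABLE_FLASHINFER_GEMM:
        # New flashinfer path: quantize per token group with row-major scales,
        # then pass transposed, contiguous scale tensors to the groupwise GEMM.
        q_input, x_scale = sglang_per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
        )
        output = gemm_fp8_nt_groupwise(
            q_input,
            weight,
            x_scale.T.contiguous(),
            weight_scale.T.contiguous(),
            out_dtype=input.dtype,
        )
    elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
        # Existing Cutlass path, unchanged by this commit: column-major scales.
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=True
        )
        output = ...  # existing Cutlass blockwise GEMM call, elided here
    else:
        output = ...  # existing fallback path, elided here

    return output.view(*input.shape[:-1], weight.shape[0])
```

Because the flashinfer branch is checked first, a Blackwell (SM100) machine with both `SGLANG_ENABLE_FLASHINFER_GEMM` and `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` set will use the flashinfer kernel; the Cutlass and fallback paths themselves are untouched by this change.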