Unverified commit 88596739 authored by weiliang, committed by GitHub

Support running FP4 Deepseek on SM120. (#11708)

parent a6ea3add
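Aside (not part of the diff): the commit replaces the SM100-only gate is_sm100_supported with a broader is_blackwell_supported helper that also accepts compute capability 12.x, so SM120 GPUs take the same FP4 / FlashInfer code paths as B200. Below is a minimal standalone sketch of that gate, simplified from the sglang.srt.utils helper added further down; the numeric CUDA-version comparison here is an assumption for clarity, while the diff compares torch.version.cuda as a string.

import torch

def is_blackwell_supported_sketch(device=None) -> bool:
    # Blackwell-class GPUs report major compute capability 10 (e.g. B200, SM100)
    # or 12 (SM120); the FP4 kernels additionally require CUDA >= 12.8.
    if not (torch.cuda.is_available() and torch.version.cuda):
        return False
    major = torch.cuda.get_device_capability(device)[0]
    cuda_version = tuple(int(p) for p in torch.version.cuda.split(".")[:2])
    return major in (10, 12) and cuda_version >= (12, 8)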
@@ -26,8 +26,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
get_int_env_var,
is_blackwell_supported,
is_flashinfer_available,
is_sm100_supported,
next_power_of_2,
)
@@ -229,7 +229,7 @@ class FlashInferAttnBackend(AttentionBackend):
]
fmha_backend = "auto"
if is_sm100_supported():
if is_blackwell_supported():
# Disable CUTLASS backend when piecewise cuda graph is enabled
# due to TMA descriptor initialization issues on B200
if model_runner.server_args.enable_piecewise_cuda_graph:
@@ -25,8 +25,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_blackwell_supported,
is_flashinfer_available,
is_sm100_supported,
next_power_of_2,
)
@@ -243,7 +243,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.q_indptr_decode = q_indptr_decode_buf
self.fmha_backend = "auto"
if is_sm100_supported():
if is_blackwell_supported():
self.fmha_backend = "cutlass"
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=self.fmha_backend
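Aside (not part of the diff): on Blackwell-class devices the MLA backend above now selects the CUTLASS FMHA path. A hedged sketch of how that choice feeds FlashInfer's ragged prefill wrapper, assuming flashinfer is installed; the workspace size is illustrative, not the value SGLang uses.

import torch
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # illustrative size
fmha_backend = "cutlass" if torch.cuda.get_device_capability()[0] in (10, 12) else "auto"
prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, "NHD", backend=fmha_backend
)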
@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader
try:
from vllm import _custom_ops as ops
@@ -129,7 +129,7 @@ def cutlass_block_fp8_supported() -> bool:
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
ENABLE_FLASHINFER_GEMM = (
get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
and is_sm100_supported()
and is_blackwell_supported()
and is_flashinfer_available()
)
if ENABLE_FLASHINFER_GEMM:
@@ -28,7 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
is_sm100_supported,
is_blackwell_supported,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
@@ -49,8 +49,10 @@ if TYPE_CHECKING:
)
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
if is_cuda():
from sgl_kernel import scaled_fp4_quant
try:
from flashinfer import fp4_quantize
except ImportError:
fp4_quantize = None
try:
from flashinfer import mm_fp4 as fp4_gemm
@@ -867,10 +869,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
output_shape = [x_m, w_n]
# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)
assert x_fp4.dtype == torch.uint8
assert x_scale_interleaved.dtype == torch.float8_e4m3fn
assert layer.weight.dtype == torch.uint8
assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32
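Aside (not part of the diff): the linear method above now quantizes activations with FlashInfer's fp4_quantize rather than sgl_kernel's scaled_fp4_quant. A hedged sketch of that call in isolation, relying only on the behaviour the surrounding asserts already pin down (packed uint8 FP4 values plus interleaved FP8-E4M3 block scales); the helper name is hypothetical.

import torch
from flashinfer import fp4_quantize  # the diff guards this import with try/except ImportError

def quantize_activations_nvfp4(x: torch.Tensor, input_scale_inv: torch.Tensor):
    # Quantize BF16/FP16 activations to NVFP4 with swizzled block scales.
    x_fp4, x_scale_interleaved = fp4_quantize(x, input_scale_inv)
    assert x_fp4.dtype == torch.uint8  # two FP4 values packed per byte
    assert x_scale_interleaved.dtype == torch.float8_e4m3fn
    return x_fp4, x_scale_interleaved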
@@ -903,7 +904,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ModelOptFp4Config):
self.quant_config = quant_config
if not is_sm100_supported():
if not is_blackwell_supported():
raise ValueError(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
@@ -1410,7 +1411,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
output_dtype = x.dtype
x_sf = None
if should_use_flashinfer_cutlass_moe_fp4_allgather():
from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
from flashinfer import nvfp4_block_scale_interleave
# Quantize before comm, swizzle after.
if x.shape[0] > 0:
@@ -131,13 +131,11 @@ from sglang.srt.utils import (
get_int_env_var,
is_cpu,
is_cuda,
is_flashinfer_available,
is_gfx95_supported,
is_hip,
is_non_idle_and_non_empty,
is_npu,
is_nvidia_cublas_cu12_version_ge_12_9,
is_sm100_supported,
log_info_on_rank0,
make_layers,
use_intel_amx_backend,
@@ -197,8 +195,6 @@ elif _is_npu:
else:
pass
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
logger = logging.getLogger(__name__)
@@ -1260,7 +1256,7 @@ class DeepseekV2AttentionMLA(nn.Module):
and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
and _is_cuda
and _device_sm >= 90
and 90 <= _device_sm < 120
)
self.qkv_proj_with_rope_is_int8 = (
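Aside (not part of the diff): the tightened bound above keeps the fused qkv_a_proj-with-MQA weight path on SM90/SM100 but excludes SM120, which this commit does not enable for that kernel. A hedged illustration of how the gate evaluates, assuming _device_sm is major * 10 + minor as returned by sglang's get_device_sm:

import torch

major, minor = torch.cuda.get_device_capability()
device_sm = major * 10 + minor  # e.g. 90 on H100, 100 on B200, 120 on SM120 parts
use_fused_qkv_a_proj_with_mqa = 90 <= device_sm < 120  # True on SM90/SM100, False on SM120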
@@ -70,18 +70,9 @@ from sglang.srt.models.utils import (
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
is_cuda,
is_flashinfer_available,
is_sm100_supported,
make_layers,
)
from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers
_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
if _is_cuda:
@@ -39,6 +39,7 @@ from sglang.srt.utils.common import (
get_device,
get_device_memory_capacity,
get_device_sm,
is_blackwell_supported,
is_cuda,
is_fa3_default_architecture,
is_flashinfer_available,
@@ -913,7 +914,7 @@ class ServerArgs:
f"- Decode: {decode_attn_backend}\n"
)
if is_sm100_supported():
if is_blackwell_supported():
if not self.enable_dp_attention:
self.enable_flashinfer_allreduce_fusion = True
logger.info(
@@ -925,7 +926,7 @@
and quantization_config.get("quant_method") == "mxfp4"
)
if is_sm100_supported() and is_mxfp4_quant_format:
if is_blackwell_supported() and is_mxfp4_quant_format:
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
@@ -1145,7 +1146,7 @@
self.attention_backend == "trtllm_mla"
or self.decode_attention_backend == "trtllm_mla"
):
if not is_sm100_supported():
if not is_blackwell_supported():
raise ValueError(
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
)
@@ -188,7 +188,16 @@ is_hopper_with_cuda_12_3 = lambda: _check(9)
def is_blackwell():
if not is_cuda():
return False
return torch.cuda.get_device_capability()[0] == 10
return torch.cuda.get_device_capability()[0] in [10, 12]
@lru_cache(maxsize=1)
def is_blackwell_supported(device=None) -> bool:
if not is_cuda_alike():
return False
return (torch.cuda.get_device_capability(device)[0] in [10, 12]) and (
torch.version.cuda >= "12.8"
)
@lru_cache(maxsize=1)
@@ -86,8 +86,8 @@ def baseline_scaled_mm(
).to(out_dtype)
def is_sm100_supported(device=None) -> bool:
return (torch.cuda.get_device_capability(device)[0] == 10) and (
def is_blackwell_supported(device=None) -> bool:
return (torch.cuda.get_device_capability(device)[0] in [10, 12]) and (
torch.version.cuda >= "12.8"
)
@@ -99,7 +99,7 @@ def is_sm90_supported(device=None) -> bool:
@pytest.mark.skipif(
not (is_sm100_supported() or is_sm90_supported()),
not (is_blackwell_supported() or is_sm90_supported()),
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
)
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])