Unverified Commit 5f91c825 authored by Jianan Ji, committed by GitHub

[Feature] Support Flashinfer fmha on Blackwell (#6930)

parent b819381f
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -149,8 +150,11 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]
+        fmha_backend = "auto"
+        if is_sm100_supported():
+            fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD"
+            self.workspace_buffer, "NHD", backend=fmha_backend
         )
         # Two wrappers: one for sliding window attention and one for full attention.
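For reference, a minimal sketch of the selection pattern this hunk introduces: on SM100 (Blackwell) devices the ragged-KV prefill wrapper is constructed with FlashInfer's `cutlass` FMHA backend, and with the default `auto` selection everywhere else. The 128 MiB workspace size below is an illustrative assumption, not necessarily what SGLang allocates.

```python
# Sketch only: mirrors the hunk above. The workspace size is an assumed
# placeholder; SGLang manages its own workspace buffer.
import torch
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

from sglang.srt.layers.utils import is_sm100_supported

workspace_buffer = torch.empty(
    128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
)

# Prefer the cutlass FMHA kernels on compute capability 10.x (Blackwell);
# let FlashInfer pick a backend automatically everywhere else.
fmha_backend = "cutlass" if is_sm100_supported() else "auto"

prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, "NHD", backend=fmha_backend
)
```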

--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -29,6 +29,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -108,8 +109,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.q_indptr_decode = q_indptr_decode_buf
+        fmha_backend = "auto"
+        if is_sm100_supported():
+            fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD"
+            self.workspace_buffer, "NHD", backend=fmha_backend
         )
         if not self.skip_prefill:
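One design observation: in both the regular and MLA backends, only the ragged-KV prefill wrappers receive an explicit `backend` argument in this diff; the paged-KV prefill and decode wrappers created elsewhere in these constructors are left on FlashInfer's default backend selection.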

--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -52,7 +52,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
-    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -63,6 +62,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,

--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.utils import is_sm100_supported
 
 try:
     from vllm import _custom_ops as ops
@@ -83,12 +84,6 @@ def cutlass_fp8_supported():
     return False
 
 
-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
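The helper itself is unchanged by this commit; it simply moves out of the quantization package so that the attention backends above can import it from the more general `sglang.srt.layers.utils` without depending on the FP8 code path.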

--- a/python/sglang/srt/layers/utils.py
+++ b/python/sglang/srt/layers/utils.py
@@ -33,3 +33,9 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
+
+
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
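A quick usage sketch of the relocated helper (assuming a CUDA build of PyTorch and at least one visible GPU):

```python
# Illustrative only: is_sm100_supported() returns True when the active GPU
# reports compute capability major == 10 (Blackwell) and the PyTorch build's
# CUDA version string is at least "12.8".
import torch

from sglang.srt.layers.utils import is_sm100_supported

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability {major}.{minor}, CUDA {torch.version.cuda}")
    print("cutlass fmha eligible:", is_sm100_supported())
```

One caveat worth noting: `torch.version.cuda >= "12.8"` is a lexicographic string comparison, which works for 12.8 and 12.9 but would misorder a hypothetical "12.10".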