"git@developer.sourcefind.cn:change/sglang.git" did not exist on "7dd8a7e6d973f2dba7f669e92b40baeeb7983248"
Unverified commit 5f91c825 authored by Jianan Ji, committed by GitHub

[Feature] Support Flashinfer fmha on Blackwell (#6930)

parent b819381f
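With this change, on SM100 (Blackwell) GPUs the FlashInfer ragged prefill wrappers in both attention backends are constructed with the "cutlass" fmha backend instead of "auto", and the is_sm100_supported helper moves from the fp8 utilities into sglang.srt.layers.utils, from which both attention backends now import it. A minimal sketch of the resulting selection logic, assuming a CUDA build of PyTorch and a FlashInfer release whose BatchPrefillWithRaggedKVCacheWrapper accepts the backend keyword used in the hunks below (the workspace size is illustrative):

    import torch
    from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

    from sglang.srt.layers.utils import is_sm100_supported

    # Illustrative workspace buffer; sglang normally allocates and reuses its own.
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    # Prefer the CUTLASS fmha kernels on SM100 (Blackwell); let FlashInfer choose
    # automatically on other architectures.
    fmha_backend = "cutlass" if is_sm100_supported() else "auto"
    prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, "NHD", backend=fmha_backend
    )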
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -149,8 +150,11 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]

+        fmha_backend = "auto"
+        if is_sm100_supported():
+            fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD"
+            self.workspace_buffer, "NHD", backend=fmha_backend
         )

         # Two wrappers: one for sliding window attention and one for full attention.
...
@@ -29,6 +29,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -108,8 +109,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.q_indptr_decode = q_indptr_decode_buf

+        fmha_backend = "auto"
+        if is_sm100_supported():
+            fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.workspace_buffer, "NHD"
+            self.workspace_buffer, "NHD", backend=fmha_backend
         )

         if not self.skip_prefill:
...
@@ -52,7 +52,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
-    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -63,6 +62,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
...
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.utils import is_sm100_supported

 try:
     from vllm import _custom_ops as ops
@@ -83,12 +84,6 @@ def cutlass_fp8_supported():
     return False


-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
...
@@ -33,3 +33,9 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
+
+
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
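For reference, a small usage sketch of the relocated helper, assuming a CUDA-enabled PyTorch build with at least one visible GPU; the helper requires compute capability major 10 (Blackwell) and a CUDA version of at least 12.8 as reported by torch.version.cuda:

    import torch

    from sglang.srt.layers.utils import is_sm100_supported

    print(torch.cuda.get_device_capability(), torch.version.cuda)
    print(is_sm100_supported())  # True only on SM100 with CUDA >= 12.8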