Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5f91c825
Unverified
Commit
5f91c825
authored
Jun 06, 2025
by
Jianan Ji
Committed by
GitHub
Jun 06, 2025
Browse files
[Feature] Support Flashinfer fmha on Blackwell (#6930)
parent
b819381f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
9 deletions
+18
-9
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+5
-1
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+5
-1
python/sglang/srt/layers/quantization/fp8.py
python/sglang/srt/layers/quantization/fp8.py
+1
-1
python/sglang/srt/layers/quantization/fp8_utils.py
python/sglang/srt/layers/quantization/fp8_utils.py
+1
-6
python/sglang/srt/layers/utils.py
python/sglang/srt/layers/utils.py
+6
-0
No files found.
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
5f91c825
...
...
@@ -25,6 +25,7 @@ from sglang.global_config import global_config
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.utils
import
is_flashinfer_available
,
next_power_of_2
...
...
@@ -149,8 +150,11 @@ class FlashInferAttnBackend(AttentionBackend):
for
_
in
range
(
self
.
num_wrappers
)
]
fmha_backend
=
"auto"
if
is_sm100_supported
():
fmha_backend
=
"cutlass"
self
.
prefill_wrapper_ragged
=
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
workspace_buffer
,
"NHD"
self
.
workspace_buffer
,
"NHD"
,
backend
=
fmha_backend
)
# Two wrappers: one for sliding window attention and one for full attention.
...
...
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
View file @
5f91c825
...
...
@@ -29,6 +29,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton
,
)
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
...
...
@@ -108,8 +109,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
else
:
self
.
q_indptr_decode
=
q_indptr_decode_buf
fmha_backend
=
"auto"
if
is_sm100_supported
():
fmha_backend
=
"cutlass"
self
.
prefill_wrapper_ragged
=
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
workspace_buffer
,
"NHD"
self
.
workspace_buffer
,
"NHD"
,
backend
=
fmha_backend
)
if
not
self
.
skip_prefill
:
...
...
python/sglang/srt/layers/quantization/fp8.py
View file @
5f91c825
...
...
@@ -52,7 +52,6 @@ from sglang.srt.layers.quantization.fp8_utils import (
cutlass_fp8_supported
,
dispatch_w8a8_block_fp8_linear
,
input_to_float8
,
is_sm100_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
)
from
sglang.srt.layers.quantization.kv_cache
import
BaseKVCacheMethod
...
...
@@ -63,6 +62,7 @@ from sglang.srt.layers.quantization.utils import (
per_tensor_dequantize
,
requantize_with_max_scale
,
)
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.utils
import
(
get_bool_env_var
,
is_cuda
,
...
...
python/sglang/srt/layers/quantization/fp8_utils.py
View file @
5f91c825
...
...
@@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple
import
torch
from
sglang.srt.layers.quantization.fp8_kernel
import
sglang_per_token_group_quant_fp8
from
sglang.srt.layers.utils
import
is_sm100_supported
try
:
from
vllm
import
_custom_ops
as
ops
...
...
@@ -83,12 +84,6 @@ def cutlass_fp8_supported():
return
False
def is_sm100_supported(device=None) -> bool:
    """Report whether the CUDA device is SM100 (Blackwell) on CUDA >= 12.8.

    Checks that the device's compute-capability major version equals 10
    and that the CUDA runtime version string is at least "12.8".

    Args:
        device: Optional device index/object forwarded to
            ``torch.cuda.get_device_capability``; ``None`` means the
            current device.

    Returns:
        True only when both conditions hold.
    """
    capability_major = torch.cuda.get_device_capability(device)[0]
    cuda_version = torch.version.cuda
    # NOTE(review): this is a lexicographic string comparison, kept as-is
    # to preserve the original behavior exactly.
    return capability_major == 10 and cuda_version >= "12.8"
def
normalize_e4m3fn_to_e4m3fnuz
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
...
...
python/sglang/srt/layers/utils.py
View file @
5f91c825
...
...
@@ -33,3 +33,9 @@ class PPMissingLayer(torch.nn.Identity):
"""
input
=
args
[
0
]
if
args
else
next
(
iter
(
kwargs
.
values
()))
return
(
input
,)
if
self
.
return_tuple
else
input
def is_sm100_supported(device=None) -> bool:
    """Report whether the CUDA device is SM100 (Blackwell) on CUDA >= 12.8.

    Args:
        device: Optional device index/object forwarded to
            ``torch.cuda.get_device_capability``; ``None`` means the
            current device.

    Returns:
        True only when the device's compute-capability major version is 10
        and the CUDA toolkit version is at least 12.8; False otherwise
        (including CPU-only torch builds, where ``torch.version.cuda`` is
        ``None``).
    """
    if torch.cuda.get_device_capability(device)[0] != 10:
        return False
    cuda_version = torch.version.cuda
    # CPU-only builds report no CUDA version; the original code would
    # raise TypeError here (None >= str). Treat that as "not supported".
    if cuda_version is None:
        return False
    # Compare numerically, not lexicographically: the original string
    # comparison breaks on two-digit minors ("12.10" >= "12.8" is False
    # as strings but True as versions).
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    return (major, minor) >= (12, 8)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment