sglang · Commits · a1c1ebe9

Unverified commit a1c1ebe9, authored Jun 25, 2025 by Yuhong Guo, committed by GitHub on Jun 25, 2025.
Fix FP8 KV Cache Support in FA3 Backend (#7148)
Parent: fe2a0f96

Changes: 3 changed files with 72 additions and 18 deletions (+72, -18).
python/sglang/srt/layers/attention/flashattention_backend.py  (+24, -14)
python/sglang/srt/model_executor/model_runner.py  (+4, -2)
test/srt/test_mla_deepseek_v3.py  (+44, -2)
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -657,12 +657,16 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         causal = not layer.is_cross_attention
 
         # Check if we should use local attention
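Note on the hunk above: the FP8 path is now taken only when an FP8 KV cache is explicitly configured (kv_cache_dtype_str != "auto"), when the layer actually carries k_scale/v_scale, and when head_dim <= 256, since the FA3 kernel expects fp16/bf16 inputs for larger head dims. The sketch below is not part of the commit; it only replays the shape and dtype logic with made-up sizes standing in for the real layer and forward_batch objects, and it assumes a torch build that exposes the float8 dtypes.

import torch

# Made-up sizes standing in for forward_batch.batch_size, layer.tp_k_head_num, layer.head_dim.
batch_size, tp_k_head_num, head_dim = 4, 8, 128
kv_cache_dtype_str = "fp8_e4m3"               # i.e. not "auto"
kv_cache_dtype = torch.float8_e4m3fn

q = torch.randn(16, tp_k_head_num, head_dim, dtype=torch.bfloat16)
k_scale = torch.tensor(0.02)                  # per-layer scalar scale, like layer.k_scale
v_scale = torch.tensor(0.03)

k_descale, v_descale = None, None
if kv_cache_dtype_str != "auto" and head_dim <= 256:
    if k_scale is not None:
        # Broadcast the scalar scales to (batch_size, num_kv_heads) for the kernel.
        descale_shape = (batch_size, tp_k_head_num)
        k_descale = k_scale.expand(descale_shape)
        v_descale = v_scale.expand(descale_shape)
    # Queries are cast to the FP8 cache dtype so they match the stored KV values.
    q = q.to(kv_cache_dtype)

print(q.dtype, k_descale.shape, v_descale.shape)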
@@ -776,8 +780,8 @@ class FlashAttentionBackend(AttentionBackend):
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +794,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # MHA for extend part of sequence without attending prefix kv cache
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
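In both varlen hunks above, k and v may come back from the FP8 KV cache while q is bf16/fp16, so the change casts the reshaped k/v views to q.dtype before they reach flash_attn_varlen_func. A minimal stand-alone illustration of that cast follows (torch only, no flash-attn import; head counts and dims are invented):

import torch

tp_q_head_num = tp_k_head_num = 8
head_dim, v_head_dim, tokens = 192, 128, 16

q = torch.randn(tokens, tp_q_head_num * head_dim, dtype=torch.bfloat16)
# k/v as they might look when read back from an FP8 KV cache.
k = torch.randn(tokens, tp_k_head_num * head_dim).to(torch.float8_e4m3fn)
v = torch.randn(tokens, tp_k_head_num * v_head_dim).to(torch.float8_e4m3fn)

# Mirror of the fixed call sites: reshape, then cast to the query dtype.
q_in = q.view(-1, tp_q_head_num, head_dim)
k_in = k.view(-1, tp_k_head_num, head_dim).to(q.dtype)
v_in = v.view(-1, tp_k_head_num, v_head_dim).to(q.dtype)
assert q_in.dtype == k_in.dtype == v_in.dtype == torch.bfloat16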
@@ -803,7 +807,9 @@ class FlashAttentionBackend(AttentionBackend):
return
output
,
lse
else
:
# Do absorbed multi-latent attention
kv_cache
=
forward_batch
.
token_to_kv_pool
.
get_key_buffer
(
layer
.
layer_id
)
kv_cache
=
forward_batch
.
token_to_kv_pool
.
get_key_buffer
(
layer
.
layer_id
).
to
(
q
.
dtype
)
k_rope
=
kv_cache
[:,
:,
layer
.
v_head_dim
:]
c_kv
=
kv_cache
[:,
:,
:
layer
.
v_head_dim
]
k_rope_cache
=
k_rope
.
view
(
...
...
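For the absorbed MLA branch, the whole latent KV buffer is fetched from the pool and, with this change, cast from the cache dtype to q.dtype right away, before it is split into the compressed part and the rope part along the last dimension. A toy replay of that split (the 3-D layout and widths are assumptions for illustration; the real buffer comes from token_to_kv_pool.get_key_buffer):

import torch

num_tokens, page_size = 32, 1
v_head_dim, rope_dim = 512, 64                 # assumed latent / rope widths

# Stand-in for token_to_kv_pool.get_key_buffer(layer.layer_id), stored as FP8.
kv_cache = torch.randn(num_tokens, page_size, v_head_dim + rope_dim).to(torch.float8_e4m3fn)

q_dtype = torch.bfloat16
kv_cache = kv_cache.to(q_dtype)                # the fix: cast once, right after fetching
c_kv = kv_cache[:, :, :v_head_dim]             # compressed latent KV
k_rope = kv_cache[:, :, v_head_dim:]           # rope part appended at the end
print(kv_cache.dtype, c_kv.shape, k_rope.shape)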
@@ -933,14 +939,16 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto":
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
             q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
             k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
@@ -1048,7 +1056,9 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
python/sglang/srt/model_executor/model_runner.py
@@ -239,7 +239,7 @@ class ModelRunner:
             "SGLANG_LOG_EXPERT_LOCATION_METADATA"
         ):
             logger.info(
-                f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
             )
 
         set_global_expert_distribution_recorder(
@@ -866,7 +866,9 @@ class ModelRunner:
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
         elif self.server_args.kv_cache_dtype == "fp8_e4m3":
-            if is_cuda():
+            if _is_hip:
+                # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e4m3fnuz
+            else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
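The fp8_e4m3 branch now selects the torch dtype by platform instead of gating on is_cuda(): ROCm gets the natively supported torch.float8_e4m3fnuz, everything else gets torch.float8_e4m3fn. A compact, hypothetical helper with the same mapping (the real code writes self.kv_cache_dtype inline and uses a module-level _is_hip flag):

import torch

def pick_fp8_e4m3_dtype(is_hip: bool) -> torch.dtype:
    """Illustrative mirror of the platform-dependent choice for kv_cache_dtype="fp8_e4m3"."""
    # ROCm uses the natively supported fnuz variant; CUDA and others use e4m3fn.
    return torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn

print(pick_fp8_e4m3_dtype(is_hip=False))   # torch.float8_e4m3fn
print(pick_fp8_e4m3_dtype(is_hip=True))    # torch.float8_e4m3fnuz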
test/srt/test_mla_deepseek_v3.py
@@ -4,7 +4,7 @@ from types import SimpleNamespace
 import requests
 import torch
 
-from sglang.srt.utils import kill_process_tree
+from sglang.srt.utils import is_cuda, is_hip, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -20,7 +20,7 @@ class TestMLADeepseekV3(CustomTestCase):
         cls.model = "lmsys/sglang-ci-dsv3-test"
         cls.base_url = DEFAULT_URL_FOR_TEST
         other_args = ["--trust-remote-code", "--chunked-prefill-size", "256"]
-        if torch.cuda.is_available() and torch.version.cuda:
+        if is_cuda():
             other_args.extend(["--enable-torch-compile", "--cuda-graph-max-bs", "2"])
         cls.process = popen_launch_server(
             cls.model,
@@ -49,6 +49,48 @@ class TestMLADeepseekV3(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.62)
 
 
+@unittest.skipIf(is_hip(), "FA is not available.")
+class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "lmsys/sglang-ci-dsv3-test"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--trust-remote-code",
+            "--chunked-prefill-size",
+            "256",
+            "--kv-cache-dtype",
+            "fp8_e4m3",
+        ]
+        if is_cuda():
+            other_args.extend(["--attention-backend", "fa3"])
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
 class TestDeepseekV3MTP(CustomTestCase):
     @classmethod
     def setUpClass(cls):
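The new TestMLADeepseekV3Fa3Fp8Kvcache class exercises exactly the combination this commit fixes: the CI DeepSeek-V3 checkpoint served with an fp8_e4m3 KV cache and, on CUDA, the fa3 attention backend, then checked against few-shot GSM8K with an accuracy threshold of 0.62. For reference, a stand-alone launch of the same configuration outside the unittest harness might look like the sketch below (popen_launch_server and the URL/timeout constants are the same test utilities imported in the test file):

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# Launch the CI DeepSeek-V3 checkpoint with an FP8 KV cache and the FA3 backend.
process = popen_launch_server(
    "lmsys/sglang-ci-dsv3-test",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--trust-remote-code",
        "--chunked-prefill-size", "256",
        "--kv-cache-dtype", "fp8_e4m3",
        "--attention-backend", "fa3",
    ],
)
try:
    pass  # send requests to DEFAULT_URL_FOR_TEST here
finally:
    kill_process_tree(process.pid)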