change / sglang · Commits

Commit 20c90be2 (unverified)
[Feature] Support FA3 backend for MLA (#4831)

Authored Mar 28, 2025 by Baizhou Zhang; committed by GitHub on Mar 28, 2025.
Parent: ec3ee028

Showing 3 changed files with 180 additions and 74 deletions (+180, -74):
- python/sglang/srt/layers/attention/flashattention_backend.py (+171, -73)
- python/sglang/srt/model_executor/model_runner.py (+5, -1)
- python/sglang/srt/models/deepseek_v2.py (+4, -0)
python/sglang/srt/layers/attention/flashattention_backend.py (view file @ 20c90be2)

@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
@@ -58,6 +60,9 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
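For context, `use_mla` distinguishes standard multi-head attention from DeepSeek-style multi-head latent attention (MLA), which this commit adds FA3 support for. Below is a minimal sketch of the same gate as a standalone helper; the helper name and the `server_args` argument are hypothetical, while `AttentionArch` and the `disable_mla` key come from the hunk above.

```python
from sglang.srt.configs.model_config import AttentionArch


def backend_uses_mla(model_runner, server_args: dict) -> bool:
    # Mirrors the __init__ check above: use MLA kernels only when the model
    # declares an MLA attention architecture and MLA has not been disabled.
    is_mla_model = model_runner.model_config.attention_arch == AttentionArch.MLA
    return is_mla_model and not server_args.get("disable_mla", False)
```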
@@ -117,23 +122,30 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )

         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

         # Use precomputed metadata
         metadata = self.forward_metadata

         # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
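The MLA branch above calls `set_kv_buffer` without the `k_scale` / `v_scale` arguments because an MLA token pool keeps one latent vector per token (compressed KV plus a small rotary key slice) rather than separate per-head K and V tensors. A rough, self-contained sketch of that layout; all names and sizes here are hypothetical and only illustrate the assumption.

```python
import torch


class MlaKvPoolSketch:
    # Hypothetical MLA pool: one buffer of width (kv_lora_rank + rope_dim) per token.
    def __init__(self, num_tokens: int, kv_lora_rank: int = 512, rope_dim: int = 64):
        self.buf = torch.zeros(num_tokens, 1, kv_lora_rank + rope_dim)

    def set_kv_buffer(self, loc: torch.Tensor, latent_with_rope: torch.Tensor) -> None:
        # `latent_with_rope` packs [c_kv | k_rope] for each written token; there is
        # no separate V buffer and no per-layer K/V scales involved.
        self.buf[loc] = latent_with_rope


pool = MlaKvPoolSketch(num_tokens=16)
pool.set_kv_buffer(torch.tensor([3, 7]), torch.randn(2, 1, 576))
```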
@@ -142,36 +154,72 @@ class FlashAttentionBackend(AttentionBackend):

The prefill path in forward_extend now dispatches on self.use_mla: the non-MLA branch carries the previously unconditional multi-head attention call, and the new MLA branch runs absorbed multi-latent attention over the latent KV cache.

            if layer.sliding_window_size is not None
            else (-1, -1)
        )

        page_table = metadata.page_table

        # # Use Flash Attention for prefill
        if not self.use_mla:
            # Do multi-head attention
            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )
            o = flash_attn_with_kvcache(
                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=metadata.max_seq_len_q,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=metadata.max_seq_len_q,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_decode(
        self,
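To make the slicing in the MLA branch concrete: `layer.head_dim` is the latent width plus the rotary width and `layer.v_head_dim` is the latent width alone, so both the query and the cached buffer split along the last dimension. The sizes below are illustrative (roughly DeepSeek-V2-like) and are not taken from the diff.

```python
import torch

num_tokens, tp_q_head_num = 8, 16
v_head_dim, rope_dim = 512, 64          # latent width and rotary width (illustrative)
head_dim = v_head_dim + rope_dim        # 576

q = torch.randn(num_tokens, tp_q_head_num * head_dim)
kv_cache = torch.randn(1024, 1, head_dim)  # one latent "head" per cached token

q_all = q.contiguous().view(-1, tp_q_head_num, head_dim)
q_nope = q_all[:, :, :v_head_dim]   # (8, 16, 512) -> passed to the kernel as qv=
q_rope = q_all[:, :, v_head_dim:]   # (8, 16, 64)  -> passed as q=

c_kv = kv_cache[:, :, :v_head_dim]    # (1024, 1, 512) -> v_cache (latent values/keys)
k_rope = kv_cache[:, :, v_head_dim:]  # (1024, 1, 64)  -> k_cache (rotary key part)
```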
@@ -184,24 +232,29 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> torch.Tensor:
         """Forward pass with FlashAttention using precomputed metadata."""
-        # Save KV cache if needed
-        if k is not None and v is not None and save_kv_cache:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

-        # Get KV cache
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
-
         # Use precomputed metadata
         metadata = self.forward_metadata

-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -210,33 +263,79 @@ class FlashAttentionBackend(AttentionBackend):

forward_decode follows the same structure: the non-MLA branch carries the previous cache lookup, reshape, and attention call, the new MLA branch runs the absorbed computation, and max_seqlen_q stays fixed at 1.

            if layer.sliding_window_size is not None
            else (-1, -1)
        )

        page_table = metadata.page_table

        if not self.use_mla:
            # Do multi-head attention

            # Get KV cache
            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            key_cache = key_cache.view(
                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
            )
            value_cache = value_cache.view(
                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
            )

            # Pre-reshape query tensor
            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

            # Run attention with precomputed values
            o = flash_attn_with_kvcache(
                q=q_reshaped,
                k_cache=key_cache,
                v_cache=value_cache,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=True,
                window_size=window_size,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
                -1,
                self.page_size,
                layer.tp_k_head_num,
                layer.head_dim - layer.v_head_dim,
            )
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )

            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

            o = flash_attn_with_kvcache(
                q=q_rope,
                k_cache=k_rope_cache,
                v_cache=c_kv_cache,
                qv=q_nope,
                page_table=page_table,
                cache_seqlens=metadata.cache_seqlens_int32,
                cu_seqlens_q=metadata.cu_seqlens_q,
                cu_seqlens_k_new=metadata.cu_seqlens_k,
                max_seqlen_q=1,
                softmax_scale=layer.scaling,
                causal=True,
                softcap=layer.logit_cap,
                k_descale=layer.k_scale,
                v_descale=layer.v_scale,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def init_cuda_graph_state(self, max_bs: int):
        """Initialize CUDA graph state for the attention backend.
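One motivation for keeping absorption on the decode path is cache size: the MLA pool stores a single latent-plus-rope vector per token instead of full K and V head tensors. A back-of-the-envelope comparison with made-up but representative fp16 sizes:

```python
# Illustrative arithmetic only: bytes cached per token per layer, fp16 (2 bytes/elem).
kv_lora_rank, rope_dim = 512, 64        # MLA latent + rotary widths (illustrative)
num_kv_heads, head_dim = 128, 128       # a hypothetical full multi-head K/V layout

mla_bytes_per_token = (kv_lora_rank + rope_dim) * 2     # 1,152 bytes
mha_bytes_per_token = 2 * num_kv_heads * head_dim * 2   # 65,536 bytes
print(mla_bytes_per_token, mha_bytes_per_token)          # ~57x smaller per token here
```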
@@ -286,7 +385,6 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
                req_pool_indices, :
            ]
            if forward_mode == ForwardMode.DECODE:
                # Precompute cumulative sequence lengths
                metadata.cu_seqlens_q = torch.arange(
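Since decode contributes exactly one query token per request, the cumulative query lengths precomputed for the CUDA graph are simply 0..batch_size, matching the `max_seqlen_q=1` used in forward_decode. A tiny sketch of what that `torch.arange` call builds (the batch size is illustrative):

```python
import torch

batch_size = 4
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
```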
python/sglang/srt/model_executor/model_runner.py (view file @ 20c90be2)
@@ -230,6 +230,10 @@ class ModelRunner:
         elif server_args.enable_flashmla:
             logger.info("MLA optimization is turned on. Use flashmla decode.")
             server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend == "fa3":
+            logger.info(
+                f"MLA optimization is turned on. Use flash attention 3 backend."
+            )
         else:
             logger.info("MLA optimization is turned on. Use triton backend.")
             server_args.attention_backend = "triton"
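Assuming the usual sglang server entry point, routing an MLA model onto this backend looks roughly like the following; the model path and the extra flags are illustrative, and only `--attention-backend fa3` is the option exercised by the code above.

```python
import subprocess

# Hypothetical launch; every argument except --attention-backend fa3 is illustrative.
subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",
        "--trust-remote-code",
        "--attention-backend", "fa3",
    ],
    check=True,
)
```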
@@ -879,7 +883,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
python/sglang/srt/models/deepseek_v2.py (view file @ 20c90be2)
@@ -655,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -667,6 +668,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
+        elif self.attention_backend == "fa3":
+            # Flash Attention: Keep absorbing for all extend/decode
+            return False
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             return (
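Summarizing the dispatch that `no_absorb` implements after this change: FA3 always keeps weight absorption, whereas the other backends decide per forward mode. A simplified sketch; the flashinfer and triton conditions are abbreviated relative to the real method, and the function name is hypothetical.

```python
def no_absorb_sketch(attention_backend: str, is_pure_prefill: bool) -> bool:
    # Returns True when the layer should run normal (non-absorbed) MLA compute.
    if attention_backend == "flashinfer":
        # Flashinfer: ragged prefill without a cached prefix can skip absorption.
        return is_pure_prefill
    elif attention_backend == "fa3":
        # Flash Attention 3: keep absorbing for all extend/decode batches.
        return False
    else:
        # Triton: normal compute for pure prefill, absorption for extend/decode.
        return is_pure_prefill
```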