sglang · Commit 6b6e7487 (unverified)

Remove q concat in FA3 backend for DeepSeek decode (#5638)

Authored Apr 23, 2025 by Ke Bao; committed by GitHub Apr 22, 2025.
Parent: 91732486

Showing 4 changed files with 40 additions and 9 deletions (+40 −9).
python/sglang/srt/layers/attention/base_attn_backend.py       +3  −0
python/sglang/srt/layers/attention/flashattention_backend.py  +22 −6
python/sglang/srt/layers/radix_attention.py                   +8  −1
python/sglang/srt/models/deepseek_v2.py                       +7  −2
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )

     def forward_decode(
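The new `**kwargs` on the dispatcher is what lets model code pass backend-specific keyword arguments (here, the `q_rope` tensor used by the FA3 MLA path) through `AttentionBackend.forward` without every backend having to declare them. A minimal sketch of the pattern, with untyped stand-ins for sglang's `RadixAttention`/`ForwardBatch` arguments (illustrative, not the repository's actual code):

    from abc import ABC, abstractmethod

    class AttentionBackend(ABC):
        """Sketch of the dispatcher: extra keywords are forwarded untouched."""

        def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
            # Decode and extend receive the same extra keywords; the base class
            # neither inspects nor validates them.
            if forward_batch.forward_mode.is_decode():
                return self.forward_decode(
                    q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
                )
            return self.forward_extend(
                q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
            )

        @abstractmethod
        def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
            ...

        @abstractmethod
        def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
            ...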
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend):
             c_kv_cache = c_kv.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]

             result = flash_attn_with_kvcache(
                 q=q_rope,
@@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend):
                 -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
             )
-            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-            q_nope = q_all[:, :, : layer.v_head_dim]
-            q_rope = q_all[:, :, layer.v_head_dim :]
+            if q_rope is not None:
+                q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+                q_rope = q_rope.view(
+                    -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+                )
+            else:
+                q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+                q_nope = q_all[:, :, : layer.v_head_dim]
+                q_rope = q_all[:, :, layer.v_head_dim :]

             max_seqlen_q = metadata.max_seq_len_q
             result = flash_attn_with_kvcache(
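In both MLA paths above, a caller that passes `q_rope` hands over a `q` that already holds only the no-position ("nope") component of width `layer.v_head_dim`, so both tensors can simply be viewed into the expected shapes; a caller that omits `q_rope` falls back to the old behavior of splitting a concatenated query. A shape-only sketch, with DeepSeek-V2-like numbers that are assumed for illustration rather than read from the repository:

    import torch

    # Assumed illustrative dimensions: 8 decode tokens, 128 query heads,
    # head_dim = 576 split into v_head_dim = 512 ("nope") + 64 rotary dims.
    tokens, heads, head_dim, v_head_dim = 8, 128, 576, 512
    rope_dim = head_dim - v_head_dim

    # New path (q_rope is not None): both parts arrive separately and are only viewed.
    q_in = torch.randn(tokens, heads * v_head_dim)
    q_rope_in = torch.randn(tokens, heads * rope_dim)
    q_nope = q_in.view(-1, heads, v_head_dim)
    q_rope = q_rope_in.view(-1, heads, rope_dim)

    # Old path (q_rope is None): one concatenated query is made contiguous and sliced.
    q_cat = torch.randn(tokens, heads, head_dim)
    q_all = q_cat.contiguous().view(-1, heads, head_dim)
    q_nope_old = q_all[:, :, :v_head_dim]
    q_rope_old = q_all[:, :, v_head_dim:]

    assert q_nope.shape == q_nope_old.shape and q_rope.shape == q_rope_old.shape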
python/sglang/srt/layers/radix_attention.py

@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
         v,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         if k is not None:
             # For cross-layer sharing, kv can be None
@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         return forward_batch.attn_backend.forward(
-            q, k, v, self, forward_batch, save_kv_cache
+            q,
+            k,
+            v,
+            self,
+            forward_batch,
+            save_kv_cache,
+            **kwargs,
         )
python/sglang/srt/models/deepseek_v2.py

@@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q = torch.cat([q_nope_out, q_pe], dim=-1)
         k = torch.cat([k_nope, k_pe], dim=-1)

-        attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+        if self.attention_backend == "fa3":
+            attn_output = self.attn_mqa(
+                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+            )
+        else:
+            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
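The rationale for the `fa3` branch: the FA3 backend immediately re-splits the query into its nope and rotary parts (see flashattention_backend.py above), so concatenating `q_nope_out` and `q_pe` in the model only to split them again inside the backend costs an extra allocation and copy of the full query on every decode step. A rough, self-contained sketch of what the skipped `torch.cat` amounts to, with assumed DeepSeek-V2-style shapes (the numbers are illustrative, not measurements from sglang):

    import torch

    # Assumed MLA decode shapes: 16 tokens in the batch, 128 local heads,
    # kv_lora_rank = 512 for the nope part, 64 rotary dims.
    tokens, heads, kv_lora_rank, rope_dim = 16, 128, 512, 64

    q_nope_out = torch.randn(tokens, heads, kv_lora_rank)
    q_pe = torch.randn(tokens, heads, rope_dim)

    # Old path: materialize a concatenated query that the backend then splits
    # right back into the same two pieces.
    q = torch.cat([q_nope_out, q_pe], dim=-1)  # fresh (tokens, heads, 576) buffer
    print(f"{q.element_size() * q.nelement() / 1024:.0f} KiB written per decode step")

    # New FA3 path: pass both pieces as-is, e.g.
    #   attn_output = self.attn_mqa(q_nope_out, k, k_nope, forward_batch, q_rope=q_pe)
    # so no concatenation (and no extra copy) is needed.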