sglang · Commits · 4793ec7d

Unverified commit 4793ec7d, authored Oct 24, 2025 by Yongfei Xu, committed by GitHub Oct 23, 2025
Opt MHA chunked prefix: merge prefix and extend kv cache to run mha once (#10953)
Parent: 92009bd2

Showing 6 changed files with 361 additions and 51 deletions (+361 -51)
python/sglang/srt/layers/attention/flashattention_backend.py    +12  -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py    +18  -10
python/sglang/srt/layers/attention/utils.py                     +78  -0
python/sglang/srt/mem_cache/memory_pool.py                      +82  -0
python/sglang/srt/model_executor/forward_batch_info.py          +35  -0
python/sglang/srt/models/deepseek_v2.py                         +136 -39
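Background for the diffs below: the existing MHA_CHUNKED_KV path attends the extend tokens and the cached prefix KV in separate passes and then merges the partial outputs with their log-sum-exp (LSE), while the new MHA_ONE_SHOT path gathers prefix and extend KV from the pool and runs a single MHA call, so no LSE merge is needed. A toy sketch (illustrative only, not sglang code; shapes are made up and the causal mask is omitted for brevity) of why the two paths agree numerically:

import torch

def attn(q, k, v):
    # Single-head attention that also returns the log-sum-exp of the scores.
    scores = q @ k.T / k.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)
    return torch.softmax(scores, dim=-1) @ v, lse

def merge(o1, lse1, o2, lse2):
    # Combine two partial attention outputs using their log-sum-exp weights.
    lse = torch.logaddexp(lse1, lse2)
    return torch.exp(lse1 - lse)[:, None] * o1 + torch.exp(lse2 - lse)[:, None] * o2

torch.manual_seed(0)
d = 16
q = torch.randn(4, d)
k_prefix, v_prefix = torch.randn(8, d), torch.randn(8, d)
k_extend, v_extend = torch.randn(4, d), torch.randn(4, d)

# Chunked path: attend prefix KV and extend KV separately, then LSE-merge.
o_prefix, lse_prefix = attn(q, k_prefix, v_prefix)
o_extend, lse_extend = attn(q, k_extend, v_extend)
chunked = merge(o_prefix, lse_prefix, o_extend, lse_extend)

# One-shot path: attend the concatenated prefix + extend KV once.
one_shot, _ = attn(q, torch.cat([k_prefix, k_extend]), torch.cat([v_prefix, v_extend]))
print(torch.allclose(chunked, one_shot, atol=1e-6))  # True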
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -855,14 +855,24 @@ class FlashAttentionBackend(AttentionBackend):
             )
         else:
             # MHA for extend part of sequence without attending prefix kv cache
+            cu_seqlens_k = (
+                metadata.cu_seqlens_q
+                if not forward_batch.mha_one_shot
+                else metadata.cu_seqlens_k
+            )
+            max_seqlen_k = (
+                metadata.max_seq_len_q
+                if not forward_batch.mha_one_shot
+                else metadata.max_seq_len_k
+            )
             output = flash_attn_varlen_func(
                 q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k=metadata.cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=metadata.max_seq_len_q,
-                max_seqlen_k=metadata.max_seq_len_q,
+                max_seqlen_k=max_seqlen_k,
                 softmax_scale=layer.scaling,
                 causal=True,
                 return_softmax_lse=forward_batch.mha_return_lse,
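In the one-shot case the key/value offsets passed to flash_attn_varlen_func must span the full cached sequences (prefix + extend), while the query offsets still cover only the extend tokens. A small worked example (illustrative numbers only, not sglang code) of the resulting cu_seqlens arrays:

import torch

prefix_lens = torch.tensor([5, 4])   # tokens already in the KV cache
extend_lens = torch.tensor([3, 2])   # new tokens being prefilled
seq_lens = prefix_lens + extend_lens

# Queries cover only the extend tokens; keys/values cover prefix + extend.
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(extend_lens, 0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seq_lens, 0), (1, 0))
print(cu_seqlens_q.tolist())  # [0, 3, 5]
print(cu_seqlens_k.tolist())  # [0, 8, 14]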
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -82,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
@@ -132,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
             )
         # ragged prefill
         if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
             self.ragged_wrapper.begin_forward(
                 qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                 num_qo_heads=self.num_local_heads,
                 num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -156,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
             chunk_idx = forward_batch.prefix_chunk_idx
             assert chunk_idx >= 0
             wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -165,7 +171,12 @@ class FlashInferMhaChunkKVRunner:
                 logits_soft_cap=logits_soft_cap,
             )
         else:
-            o1, s1 = self.ragged_wrapper.forward_return_lse(
+            forward = (
+                self.ragged_wrapper.forward_return_lse
+                if forward_batch.mha_return_lse
+                else self.ragged_wrapper.forward
+            )
+            o = forward(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -173,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-        return o1, s1
+        return o


 class FlashInferMLAAttnBackend(AttentionBackend):
@@ -512,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if (
-            forward_batch.attn_attend_prefix_cache is not None
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
         ):
             # MHA Chunk
             assert self.enable_chunk_kv
             assert q_rope is None
             assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)

         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
python/sglang/srt/layers/attention/utils.py

+import torch
 import triton
 import triton.language as tl
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
         data // PAGED_SIZE,
         mask=mask_out,
     )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
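A hypothetical usage sketch for the new helper (illustrative only, not part of this commit; it needs a CUDA device with Triton, and the shapes follow the DeepSeek MLA layout implied by the asserts above: per-head nope vectors plus one rope vector shared across heads):

import torch
from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton

tokens, heads, nope_dim, rope_dim = 32, 128, 128, 64
k_nope = torch.randn(tokens, heads, nope_dim, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(tokens, 1, rope_dim, device="cuda", dtype=torch.bfloat16)

# The destination may use a different dtype (e.g. an fp8 KV-cache dtype);
# the kernel casts implicitly on store.
k = k_nope.new_empty(tokens, heads, nope_dim + rope_dim)
concat_and_cast_mha_k_triton(k, k_nope, k_rope)

# Eager equivalent: concatenate nope with the rope vector broadcast over heads.
ref = torch.cat([k_nope, k_rope.expand(tokens, heads, rope_dim)], dim=-1)
assert torch.equal(k, ref)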
python/sglang/srt/mem_cache/memory_pool.py

@@ -1213,6 +1213,65 @@ def set_mla_kv_buffer_triton(
     )
+
+
+@triton.jit
+def get_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    loc = tl.load(loc_ptr + pid_loc)
+    loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
+
+    nope_offs = tl.arange(0, nope_dim)
+    nope_src_ptr = loc_src_ptr + nope_offs
+    nope_src = tl.load(nope_src_ptr)
+    tl.store(
+        cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
+        nope_src,
+    )
+
+    rope_offs = tl.arange(0, rope_dim)
+    rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
+    rope_src = tl.load(rope_src_ptr)
+    tl.store(
+        cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
+        rope_src,
+    )
+
+
+def get_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    nope_dim = cache_k_nope.shape[-1]  # 512
+    rope_dim = cache_k_rope.shape[-1]  # 64
+    n_loc = loc.numel()
+    grid = (n_loc,)
+
+    get_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )


 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -1363,6 +1422,29 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope,
         )

+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        # get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
+        layer_id = layer.layer_id
+        kv_buffer = self.get_key_buffer(layer_id)
+        dst_dtype = dst_dtype or self.dtype
+        cache_k_nope = torch.empty(
+            (loc.shape[0], 1, self.kv_lora_rank),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        cache_k_rope = torch.empty(
+            (loc.shape[0], 1, self.qk_rope_head_dim),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
+        return cache_k_nope, cache_k_rope
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
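An illustrative sketch of what the new gather helper does (not sglang test code; it needs CUDA with Triton, and the pool size is hypothetical while the latent dimensions follow the 512 + 64 comments above):

import torch
from sglang.srt.mem_cache.memory_pool import get_mla_kv_buffer_triton

kv_lora_rank, rope_dim, pool_size = 512, 64, 1024
# Latent KV pool: one (kv_lora_rank + rope_dim)-wide row per cache slot.
kv_buffer = torch.randn(pool_size, 1, kv_lora_rank + rope_dim, device="cuda", dtype=torch.bfloat16)
loc = torch.tensor([3, 42, 7], device="cuda", dtype=torch.int64)

cache_k_nope = torch.empty(loc.numel(), 1, kv_lora_rank, device="cuda", dtype=torch.bfloat16)
cache_k_rope = torch.empty(loc.numel(), 1, rope_dim, device="cuda", dtype=torch.bfloat16)
get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)

# Eager equivalent of the gather + split the kernel performs.
ref = kv_buffer[loc]
assert torch.equal(cache_k_nope, ref[..., :kv_lora_rank])
assert torch.equal(cache_k_rope, ref[..., kv_lora_rank:])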
python/sglang/srt/model_executor/forward_batch_info.py

@@ -39,6 +39,7 @@ import triton
 import triton.language as tl
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
@@ -250,6 +251,8 @@ class ForwardBatch:
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether lse needs to be returned
     mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None

     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -863,6 +866,10 @@ class ForwardBatch:
             self.token_to_kv_pool, MLATokenToKVPool
         ), "Currently chunked prefix cache can only be used by Deepseek models"

+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
         if self.prefix_chunk_len is not None:
             # Chunked kv cache info already prepared by prior modules
             return
@@ -917,6 +924,34 @@ class ForwardBatch:
     def can_run_tbo(self):
         return self.tbo_split_seq_index is not None

+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+

 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
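A toy illustration (not sglang code; the table values are hypothetical) of the indptr/indices layout that fetch_mha_one_shot_kv_indices builds, where each request contributes its full sequence of KV-cache slot indices (prefix plus extend) back to back:

import torch

req_to_token = torch.tensor([   # hypothetical request -> KV-cache slot table
    [10, 11, 12, 13,  0,  0],
    [20, 21, 22, 23, 24, 25],
])
req_pool_indices = torch.tensor([0, 1])
seq_lens = torch.tensor([4, 6])  # full lengths: prefix + extend

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.cat(
    [req_to_token[r, :l] for r, l in zip(req_pool_indices.tolist(), seq_lens.tolist())]
)
print(kv_indptr.tolist())   # [0, 4, 10]
print(kv_indices.tolist())  # [10, 11, 12, 13, 20, 21, 22, 23, 24, 25]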
python/sglang/srt/models/deepseek_v2.py

@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -241,6 +242,10 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()

+    # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+    # when the sequence lengths are below the threshold.
+    MHA_ONE_SHOT = auto()
+
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
@@ -306,6 +311,14 @@ def _is_extend_without_speculative(forward_batch):
     )


+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
 def _handle_attention_backend(
     attn: DeepseekV2AttentionMLA, forward_batch, backend_name
 ):
@@ -325,6 +338,8 @@ def _handle_attention_backend(
             or sum_extend_prefix_lens == 0
         )
     ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
         return AttnForwardMethod.MHA_CHUNKED_KV
     else:
         return _dispatch_mla_subtype(attn, forward_batch)
@@ -1062,6 +1077,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype

         # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
         if rope_scaling:
@@ -1359,6 +1375,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_normal_chunked_kv_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
@@ -1410,6 +1430,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
             return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1444,41 +1466,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-
-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
         if (
-            _is_cuda
-            and (self.num_local_heads == 128)
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
-            )
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
+            )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]

+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
         return q, k, v, forward_batch

     def forward_normal_core(self, q, k, v, forward_batch):
@@ -2288,20 +2293,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         for i in range(forward_batch.num_prefix_chunks):
             forward_batch.set_prefix_chunk_idx(i)
+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]

             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
-            )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
-            )
-            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(
+                kv_indices, q.dtype, forward_batch
+            )
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(
                 -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2376,6 +2372,107 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output

+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 mha support fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
     @staticmethod
     def _get_q_b_proj_quant_config(quant_config):
         if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
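A small worked example of the dispatch condition, mirroring _support_mha_one_shot above (the numbers and the chunk-capacity value are assumptions for illustration, not values taken from this commit): the backend must be one of fa3 / flashinfer / flashmla and the summed sequence length of the batch must fit within get_max_chunk_capacity().

backend_name = "fa3"
seq_lens_cpu = [4096, 8192, 2048]
max_chunk_capacity = 128 * 1024          # assumed capacity, for illustration only

attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
one_shot = attn_supported and sum(seq_lens_cpu) <= max_chunk_capacity
print(one_shot)  # True: 14336 <= 131072, so MHA_ONE_SHOT would be chosen over MHA_CHUNKED_KV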