change / sglang · Commits

Unverified commit 4793ec7d, authored Oct 24, 2025 by Yongfei Xu, committed via GitHub Oct 23, 2025

Opt MHA chunked prefix: merge prefix and extend kv cache to run mha once (#10953)

parent 92009bd2

Showing 6 changed files with 361 additions and 51 deletions (+361 -51)
python/sglang/srt/layers/attention/flashattention_backend.py    +12   -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py    +18  -10
python/sglang/srt/layers/attention/utils.py                     +78   -0
python/sglang/srt/mem_cache/memory_pool.py                      +82   -0
python/sglang/srt/model_executor/forward_batch_info.py          +35   -0
python/sglang/srt/models/deepseek_v2.py                        +136  -39
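As a self-contained illustration (toy PyTorch, not sglang code) of the optimization named in the title: the chunked-prefix path attends the extend tokens to the cached prefix KV and to the extend KV in separate passes and merges the partial outputs with their log-sum-exp (LSE), while the new one-shot path concatenates prefix and extend KV and runs MHA once. The sketch below checks that the two are numerically equivalent; all shapes and names are illustrative.

    import math
    import torch

    torch.manual_seed(0)
    H, D = 2, 16            # heads, head dim
    P, E = 5, 3             # prefix length, extend (new token) length

    q = torch.randn(H, E, D)                       # queries exist only for the extend tokens
    k_prefix, v_prefix = torch.randn(H, P, D), torch.randn(H, P, D)
    k_ext, v_ext = torch.randn(H, E, D), torch.randn(H, E, D)
    scale = 1.0 / math.sqrt(D)

    def attn_with_lse(q, k, v, causal):
        s = (q @ k.transpose(-1, -2)) * scale
        if causal:
            mask = torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool).tril()
            s = s.masked_fill(~mask, float("-inf"))
        lse = torch.logsumexp(s, dim=-1)
        return torch.softmax(s, dim=-1) @ v, lse

    # Two-pass chunked-prefix style: prefix part (no mask) and extend part (causal),
    # merged by weighting each partial output with its LSE.
    o1, lse1 = attn_with_lse(q, k_prefix, v_prefix, causal=False)
    o2, lse2 = attn_with_lse(q, k_ext, v_ext, causal=True)
    w1 = torch.sigmoid(lse1 - lse2).unsqueeze(-1)  # = exp(lse1) / (exp(lse1) + exp(lse2))
    merged = w1 * o1 + (1 - w1) * o2

    # One-shot: concatenate prefix and extend KV and attend once; causal mask only
    # applies within the extend block, the prefix is fully visible.
    k_all = torch.cat([k_prefix, k_ext], dim=-2)
    v_all = torch.cat([v_prefix, v_ext], dim=-2)
    s = (q @ k_all.transpose(-1, -2)) * scale
    mask = torch.ones(E, P + E, dtype=torch.bool)
    mask[:, P:] = mask[:, P:].tril()
    s = s.masked_fill(~mask, float("-inf"))
    one_shot = torch.softmax(s, dim=-1) @ v_all

    assert torch.allclose(merged, one_shot, atol=1e-5)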
python/sglang/srt/layers/attention/flashattention_backend.py (+12 -2)
...
...
@@ -855,14 +855,24 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
+                cu_seqlens_k = (
+                    metadata.cu_seqlens_q
+                    if not forward_batch.mha_one_shot
+                    else metadata.cu_seqlens_k
+                )
+                max_seqlen_k = (
+                    metadata.max_seq_len_q
+                    if not forward_batch.mha_one_shot
+                    else metadata.max_seq_len_k
+                )
                 output = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                     v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
                     max_seqlen_q=metadata.max_seq_len_q,
-                    max_seqlen_k=metadata.max_seq_len_q,
+                    max_seqlen_k=max_seqlen_k,
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
...
...
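In the one-shot path the query tensor still covers only the extend tokens, while K/V now cover the full sequences, so the key-side cumulative lengths handed to flash_attn_varlen_func must be built from full sequence lengths rather than reused from the query side. A minimal sketch with hypothetical lengths (not sglang metadata):

    import torch

    extend_lens = torch.tensor([3, 2])        # new tokens per request
    prefix_lens = torch.tensor([5, 7])        # cached prefix tokens per request
    seq_lens = extend_lens + prefix_lens      # full KV length per request

    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(extend_lens, 0), (1, 0)).int()
    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seq_lens, 0), (1, 0)).int()

    print(cu_seqlens_q.tolist())  # [0, 3, 5]
    print(cu_seqlens_k.tolist())  # [0, 8, 17]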
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+18 -10)
...
...
@@ -82,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
...
@@ -132,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
         )
         # ragged prefill
         if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
             self.ragged_wrapper.begin_forward(
                 qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                 num_qo_heads=self.num_local_heads,
                 num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
...
@@ -156,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
             chunk_idx = forward_batch.prefix_chunk_idx
             assert chunk_idx >= 0
             wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
...
@@ -165,7 +171,12 @@ class FlashInferMhaChunkKVRunner:
                 logits_soft_cap=logits_soft_cap,
             )
         else:
-            o1, s1 = self.ragged_wrapper.forward_return_lse(
+            forward = (
+                self.ragged_wrapper.forward_return_lse
+                if forward_batch.mha_return_lse
+                else self.ragged_wrapper.forward
+            )
+            o = forward(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
...
@@ -173,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-        return o1, s1
+        return o


 class FlashInferMLAAttnBackend(AttentionBackend):
...
@@ -512,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if (
-            forward_batch.attn_attend_prefix_cache is not None
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
         ):
             # MHA Chunk
             assert self.enable_chunk_kv
             assert q_rope is None
             assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)

         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
...
...
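The ragged-prefill refactor above also changes the shape of what the runner returns: it yields (output, lse) only when the caller still has prefix chunks to merge, and a plain output in the one-shot case. A hedged sketch of that pattern with stand-in functions (not the FlashInfer API):

    import torch

    def attention_plain(q):                  # stand-in for ragged_wrapper.forward
        return torch.zeros_like(q)

    def attention_with_lse(q):               # stand-in for ragged_wrapper.forward_return_lse
        return torch.zeros_like(q), torch.zeros(q.shape[:-1])

    def run(q, mha_return_lse: bool):
        forward = attention_with_lse if mha_return_lse else attention_plain
        return forward(q)

    q = torch.randn(4, 2, 8)
    o = run(q, mha_return_lse=False)         # one-shot: a single tensor
    o1, s1 = run(q, mha_return_lse=True)     # chunked: output plus LSE for merging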
python/sglang/srt/layers/attention/utils.py (+78 -0)
import torch
import triton
import triton.language as tl
...
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
         data // PAGED_SIZE,
         mask=mask_out,
     )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
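For readers who want to sanity-check the new kernel, here is a plain-PyTorch reference of what concat_and_cast_mha_k_triton computes under the same shape contract as the asserts above: copy the per-head nope part, broadcast the single rope head across all heads, and cast both into k's dtype. Shapes below are illustrative only.

    import torch

    def concat_and_cast_mha_k_reference(k, k_nope, k_rope):
        # k:      (tokens, heads, nope_dim + rope_dim), target dtype
        # k_nope: (tokens, heads, nope_dim)
        # k_rope: (tokens, 1, rope_dim), broadcast across heads
        nope_dim = k_nope.shape[-1]
        k[..., :nope_dim] = k_nope.to(k.dtype)
        k[..., nope_dim:] = k_rope.to(k.dtype)
        return k

    tokens, heads, nope_dim, rope_dim = 4, 8, 16, 4
    k_nope = torch.randn(tokens, heads, nope_dim, dtype=torch.float32)
    k_rope = torch.randn(tokens, 1, rope_dim, dtype=torch.float32)
    k = torch.empty(tokens, heads, nope_dim + rope_dim, dtype=torch.float16)
    concat_and_cast_mha_k_reference(k, k_nope, k_rope)

The Triton version fuses the copy, broadcast, and cast into one kernel launch per token, which avoids the extra elementwise passes this reference would incur.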
python/sglang/srt/mem_cache/memory_pool.py (+82 -0)
...
...
@@ -1213,6 +1213,65 @@ def set_mla_kv_buffer_triton(
     )
+
+
+@triton.jit
+def get_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    loc = tl.load(loc_ptr + pid_loc)
+    loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
+
+    nope_offs = tl.arange(0, nope_dim)
+    nope_src_ptr = loc_src_ptr + nope_offs
+    nope_src = tl.load(nope_src_ptr)
+    tl.store(
+        cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
+        nope_src,
+    )
+
+    rope_offs = tl.arange(0, rope_dim)
+    rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
+    rope_src = tl.load(rope_src_ptr)
+    tl.store(
+        cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
+        rope_src,
+    )
+
+
+def get_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    nope_dim = cache_k_nope.shape[-1]  # 512
+    rope_dim = cache_k_rope.shape[-1]  # 64
+    n_loc = loc.numel()
+    grid = (n_loc,)
+
+    get_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )


 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
...
...
@@ -1363,6 +1422,29 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope,
         )

+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        # get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
+        layer_id = layer.layer_id
+        kv_buffer = self.get_key_buffer(layer_id)
+        dst_dtype = dst_dtype or self.dtype
+        cache_k_nope = torch.empty(
+            (loc.shape[0], 1, self.kv_lora_rank),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        cache_k_rope = torch.empty(
+            (loc.shape[0], 1, self.qk_rope_head_dim),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
+        return cache_k_nope, cache_k_rope
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
...
...
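What get_mla_kv_buffer_triton does can also be written as a gather-and-split in plain PyTorch; a reference sketch with a toy pool size (the 512/64 split matches the comments above, everything else is illustrative):

    import torch

    def get_mla_kv_buffer_reference(kv_buffer, loc, rope_dim, dst_dtype):
        # kv_buffer: (pool_size, 1, kv_lora_rank + qk_rope_head_dim); loc: (n,) slots to read
        rows = kv_buffer[loc].to(dst_dtype)          # gather the selected slots, then cast
        return rows[..., :-rope_dim], rows[..., -rope_dim:]

    pool = torch.randn(100, 1, 576, dtype=torch.bfloat16)   # 512 nope + 64 rope per slot
    loc = torch.tensor([3, 17, 42])
    cache_k_nope, cache_k_rope = get_mla_kv_buffer_reference(pool, loc, 64, torch.float16)
    print(cache_k_nope.shape, cache_k_rope.shape)           # (3, 1, 512) and (3, 1, 64)

The fused kernel avoids the intermediate gathered tensor this reference materializes and writes the nope and rope halves directly into the preallocated output buffers.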
python/sglang/srt/model_executor/forward_batch_info.py (+35 -0)
...
...
@@ -39,6 +39,7 @@ import triton
 import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
...
@@ -250,6 +251,8 @@ class ForwardBatch:
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether lse needs to be returned
     mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None

     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
...
@@ -863,6 +866,10 @@ class ForwardBatch:
             self.token_to_kv_pool, MLATokenToKVPool
         ), "Currently chunked prefix cache can only be used by Deepseek models"

+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
         if self.prefix_chunk_len is not None:
             # Chunked kv cache info already prepared by prior modules
             return
...
@@ -917,6 +924,34 @@ class ForwardBatch:
     def can_run_tbo(self):
         return self.tbo_split_seq_index is not None

+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+

 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
...
...
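fetch_mha_one_shot_kv_indices flattens, per request, the first seq_len entries of the req_to_token table into one int32 index array covering prefix plus extend tokens, and caches it on the batch. A plain-PyTorch equivalent of that gather, with a hypothetical two-request table:

    import torch

    req_to_token = torch.tensor([[10, 11, 12, 13, 0, 0],
                                 [20, 21, 22, 23, 24, 0]], dtype=torch.int32)
    req_pool_indices = torch.tensor([0, 1])
    seq_lens = torch.tensor([4, 5])              # prefix + extend length per request

    kv_indices = torch.cat([
        req_to_token[req, :length]
        for req, length in zip(req_pool_indices.tolist(), seq_lens.tolist())
    ])
    print(kv_indices.tolist())  # [10, 11, 12, 13, 20, 21, 22, 23, 24]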
python/sglang/srt/models/deepseek_v2.py (+136 -39)
...
...
@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
...
@@ -241,6 +242,10 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()

+    # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+    # when the sequence lengths are below the threshold.
+    MHA_ONE_SHOT = auto()
+
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
...
@@ -306,6 +311,14 @@ def _is_extend_without_speculative(forward_batch):
     )


+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
 def _handle_attention_backend(
     attn: DeepseekV2AttentionMLA, forward_batch, backend_name
 ):
...
@@ -325,6 +338,8 @@ def _handle_attention_backend(
             or sum_extend_prefix_lens == 0
         )
     ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
         return AttnForwardMethod.MHA_CHUNKED_KV
     else:
         return _dispatch_mla_subtype(attn, forward_batch)
...
...
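The dispatch change above only takes the new path when the whole batch's KV fits within the existing chunk capacity and the attention backend supports it; otherwise it falls back to the chunked path. A condensed sketch of that decision (the numbers are hypothetical; get_max_chunk_capacity is the existing sglang helper the real check calls):

    def choose_mha_path(backend_name, seq_lens, max_chunk_capacity):
        # Mirrors _support_mha_one_shot: one-shot only for supported backends and only
        # when prefix + extend KV of the whole batch fits in a single attention pass.
        if backend_name in ("fa3", "flashinfer", "flashmla") and sum(seq_lens) <= max_chunk_capacity:
            return "MHA_ONE_SHOT"
        return "MHA_CHUNKED_KV"

    print(choose_mha_path("fa3", [4096, 2048], max_chunk_capacity=128 * 1024))       # MHA_ONE_SHOT
    print(choose_mha_path("fa3", [200_000, 60_000], max_chunk_capacity=128 * 1024))  # MHA_CHUNKED_KV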
@@ -1062,6 +1077,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype

         # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
         if rope_scaling:
...
@@ -1359,6 +1375,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_normal_chunked_kv_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
...
@@ -1410,6 +1430,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
             return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
...
@@ -1444,41 +1466,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-
-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
         if (
-            _is_cuda
-            and (self.num_local_heads == 128)
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
-            )
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
+            )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
         return q, k, v, forward_batch

     def forward_normal_core(self, q, k, v, forward_batch):
...
...
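The rewritten forward_normal_prepare above is where "merge prefix and extend kv cache" happens: the pool stores only the compressed latent per token, so in one-shot mode the prefix tokens' kv_a/k_pe are re-read from the pool and pushed through kv_b_proj together with the new tokens, yielding per-head k/v for the full sequence. A toy-shape sketch of that expansion (sizes are illustrative; the real DeepSeek V3/R1 config uses kv_lora_rank=512, qk_rope_head_dim=64, and 128 heads with 128-dim nope parts):

    import torch

    num_tokens, num_heads = 10, 4
    kv_lora_rank, qk_rope_head_dim = 32, 8
    qk_nope_head_dim, v_head_dim = 16, 16

    kv_a = torch.randn(num_tokens, kv_lora_rank)            # latent rows gathered from the pool
    k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)     # rope part, shared across heads
    kv_b_proj = torch.nn.Linear(
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False
    )

    kv = kv_b_proj(kv_a).view(num_tokens, num_heads, qk_nope_head_dim + v_head_dim)
    k_nope = kv[..., :qk_nope_head_dim]
    v = kv[..., qk_nope_head_dim:]
    k = torch.cat([k_nope, k_pe.expand(-1, num_heads, -1)], dim=-1)  # what _concat_and_cast_mha_k builds
    print(k.shape, v.shape)  # torch.Size([10, 4, 24]) torch.Size([10, 4, 16])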
@@ -2288,20 +2293,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         for i in range(forward_batch.num_prefix_chunks):
             forward_batch.set_prefix_chunk_idx(i)
+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]

             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
-            )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
-            )
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(kv_indices, q.dtype, forward_batch)
             kv_a_normed = kv_a_normed.squeeze(1).contiguous()
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(
                 -1,
                 self.num_local_heads,
                 self.qk_nope_head_dim + self.v_head_dim,
...
@@ -2376,6 +2372,107 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output

+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 mha support fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
     @staticmethod
     def _get_q_b_proj_quant_config(quant_config):
         if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
...
...