Unverified commit ebd9dbe7, authored Aug 25, 2025 by Yineng Zhang; committed by GitHub on Aug 25, 2025.

fix: revert #8593 (#9581)
Parent: 938e986e
Showing 5 changed files with 103 additions and 290 deletions (+103, -290):
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+71, -89)
python/sglang/srt/layers/attention/utils.py (+15, -94)
python/sglang/srt/server_args.py (+0, -4)
test/srt/test_create_kvindices.py (+17, -59)
test/srt/test_mla_flashinfer.py (+0, -44)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (view file @ ebd9dbe7)

@@ -24,7 +24,9 @@ if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -179,6 +181,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()

         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device

@@ -210,25 +213,15 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
-        self.kv_indices = torch.empty(
-            (max_bs * (self.max_context_len + self.page_size - 1) // self.page_size,),
-            dtype=torch.int32,
-            device=model_runner.device,
-        )

         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
             )

         if q_indptr_decode_buf is None:
+            # A hack to pre-initialize large batch size for dp attention
+            if model_runner.server_args.enable_dp_attention:
+                max_bs = model_runner.server_args.dp_size * max_bs
             self.q_indptr_decode = torch.arange(
                 0, max_bs + 1, dtype=torch.int32, device=model_runner.device
             )
         else:
             self.q_indptr_decode = q_indptr_decode_buf

@@ -273,7 +266,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}  # For verify

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,

@@ -331,9 +323,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
-        self.cuda_graph_kv_indices = (
-            self.kv_indices.clone()
-            if kv_indices_buf is None
-            else kv_indices_buf
-        )
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
         self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
         self.cuda_graph_kv_indptr = self.kv_indptr.clone()
         self.cuda_graph_kv_lens = torch.ones(

@@ -359,7 +358,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,

@@ -370,6 +368,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,

@@ -440,13 +439,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
             assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
-            num_pages_per_req = (seq_lens_cpu + self.page_size - 1) // self.page_size
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
-                num_pages_per_req, dim=0
+                kv_len_arr_cpu, dim=0
             )
             self.fast_decode_kwargs.update(
                 {

@@ -455,6 +452,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     "kv_len_arr_cpu": kv_len_arr_cpu,
                 }
             )
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],

@@ -534,6 +532,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
         if self.forward_metadata.use_ragged:
             # ragged prefill
             if q_rope is not None:

@@ -554,8 +553,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])
             if q_rope is None:
                 qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                 q, q_rope = (

@@ -617,17 +614,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             q_nope = reshaped_q[:, :, : layer.v_head_dim]
             q_rope = reshaped_q[:, :, layer.v_head_dim :]

-            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
                 q.dtype
             )
-            k_buf = k_buf.view(-1, self.page_size, k_buf.shape[-1])

             o = q_nope.new_empty(q_nope.shape)
+            # Direct call to run without the wrapper
             o = decode_wrapper.run(
                 q_nope,
                 q_rope,
-                k_buf[:, :, : layer.v_head_dim],
-                k_buf[:, :, layer.v_head_dim :],
+                k_buffer[:, :, : layer.v_head_dim],
+                k_buffer[:, :, layer.v_head_dim :],
                 out=o,
             )

@@ -646,10 +643,9 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.scaling = model_runner.model_config.scaling
         self.data_type = model_runner.dtype
         self.attn_backend = attn_backend
-        self.page_size = model_runner.page_size

         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode

@@ -693,17 +689,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
         if spec_info is None:
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = (
-                self.kv_indices[: kv_indptr[-1]]
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                 if not init_metadata_replay
                 else fast_decode_kwargs["kv_indices"]
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,

@@ -712,40 +704,39 @@ class FlashInferMLAIndicesUpdaterDecode:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
         else:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

         if not init_metadata_replay:
             wrapper.plan(
-                qo_indptr=q_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )
         else:
             wrapper.plan(
-                qo_indptr_cpu=fast_decode_kwargs["qo_indptr_cpu"],
-                kv_indptr_cpu=fast_decode_kwargs["kv_indptr_cpu"],
-                kv_indices=kv_indices,
-                kv_len_arr_cpu=fast_decode_kwargs["kv_len_arr_cpu"],
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=False,
-                sm_scale=sm_scale,
-                q_data_type=self.data_type,
-                kv_data_type=self.data_type,
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
             )

@@ -767,14 +758,12 @@ class FlashInferMLAIndicesUpdaterPrefill:
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.qo_indptr = attn_backend.qo_indptr
-        self.kv_indices = attn_backend.kv_indices
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
-        self.page_size = model_runner.page_size

     def update(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,

@@ -788,6 +777,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         else:
             paged_kernel_lens = seq_lens
             paged_kernel_lens_sum = seq_lens_sum
         self.call_begin_forward(
             self.prefill_wrapper_ragged,
             prefill_wrapper_paged,

@@ -821,12 +811,13 @@ class FlashInferMLAIndicesUpdaterPrefill:
         if spec_info is None:
             assert len(seq_lens) == len(req_pool_indices)
-            num_pages_per_req = (
-                paged_kernel_lens + self.page_size - 1
-            ) // self.page_size
-            kv_indptr[1 : bs + 1] = torch.cumsum(num_pages_per_req, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = self.kv_indices[: kv_indptr[-1]]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,

@@ -835,7 +826,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 None,
                 kv_indices,
                 self.req_to_token.shape[1],
-                self.page_size,
             )
             qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]

@@ -853,6 +843,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
                 self.req_to_token,
             )
         )
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(

@@ -867,26 +858,20 @@ class FlashInferMLAIndicesUpdaterPrefill:
             )
         else:
             # mla paged prefill
-            if spec_info is not None:
-                assert (
-                    self.page_size == 1
-                ), "Only page_size=1 is supported for flashinfer backend with speculative decoding"
-                kv_lens = kv_indptr[1:] - kv_indptr[:-1]
-            else:
-                kv_lens = paged_kernel_lens.to(torch.int32)
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
             wrapper_paged.plan(
-                qo_indptr=qo_indptr,
-                kv_indptr=kv_indptr,
-                kv_indices=kv_indices,
-                kv_len_arr=kv_lens,
-                num_heads=self.num_local_heads,
-                head_dim_ckv=self.kv_lora_rank,
-                head_dim_kpe=self.qk_rope_head_dim,
-                page_size=self.page_size,
-                causal=True,
-                sm_scale=sm_scale,
-                q_data_type=self.q_data_type,
-                kv_data_type=self.data_type,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                True,
+                sm_scale,
+                self.q_data_type,
+                self.data_type,
             )

@@ -981,7 +966,6 @@ class FlashInferMLAMultiStepDraftBackend:
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
                 self.speculative_num_steps,

@@ -1017,7 +1001,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                 forward_batch.batch_size,

@@ -1034,7 +1017,6 @@ class FlashInferMLAMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
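The hunks above move the decode and prefill index updaters from the page-granular bookkeeping introduced by #8593 back to token-granular bookkeeping, with the page size passed to wrapper.plan() fixed at 1. A minimal sketch of the two kv_indptr constructions the diff switches between, using only the formulas that appear in the hunks (the tensor values are made-up examples, not from the repository):

import torch

# Hypothetical sequence lengths for a batch of 3 requests.
seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32)
page_size = 4  # only relevant to the reverted, page-granular variant

# Restored behaviour: kv_indptr counts tokens, one KV-cache slot per token.
kv_indptr_tokens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr_tokens[1:] = torch.cumsum(seq_lens, dim=0)
# -> tensor([ 0,  5,  8, 16])

# Reverted behaviour (from #8593): kv_indptr counts pages of `page_size` tokens.
num_pages_per_req = (seq_lens + page_size - 1) // page_size
kv_indptr_pages = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr_pages[1:] = torch.cumsum(num_pages_per_req, dim=0)
# -> tensor([0, 2, 3, 5])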
python/sglang/srt/layers/attention/utils.py (view file @ ebd9dbe7)

@@ -9,89 +9,18 @@ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
 @triton.jit
 def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,
+    req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
-    PAGE_SIZE: tl.constexpr = 1,
 ):
-    """
-    Create KV indices for FlashInfer attention backend.
-
-    This Triton kernel builds a lookup table that maps from logical request/token
-    coordinates to physical token locations in the global KV cache pool. It's used
-    by FlashInfer attention backends to efficiently access scattered KV cache data.
-
-    The kernel processes each request in parallel and converts the req_to_token
-    lookup table into a flat list of token indices that can be used by attention kernels.
-
-    General idea:
-        blocktables/kv_indices_ptr = [batch_size * max_pages (for graph mode with a
-        fixed number of pages)]
-        max_pages = max_context_len / PAGED_SIZE
-        kv_indices_ptr will store the flat list of the pages used by each request.
-
-    Args:
-        Input arguments (non-mutable):
-            req_to_token_ptr: Request-to-token-location lookup table.
-                Shape: [max_batch, max_context_len]
-            req_pool_indices_ptr: Request-to-pool-index lookup table. Each request uses
-                one pool.
-                Shape: [batch_size]
-            page_kernel_lens_ptr: Sequence lengths per request.
-                Shape: [batch_size]
-            kv_indptr: Should be computed based on the number of pages used by each
-                request. It is used by FlashInfer attention kernels to index into
-                kv_indices_ptr per request.
-                Shape: [batch_size + 1]
-                kv_indptr[i] = start index in kv_indices for request i
-            kv_start_idx: Pointer to an array containing start offsets for each request
-                in SGL. Can be None. If provided, adds an offset to token positions.
-            req_to_token_ptr_stride: Stride for the second dimension of req_to_token.
-                Equal to max_context_len.
-            PAGED_SIZE: Number of tokens per page. Default is 1 for FlashInfer.
-
-        Outputs:
-            kv_indices_ptr: Pointer to the output array where KV indices will be stored.
-                Shape: [total_num_pages], where total_num_pages = sum(seq_lens // PAGED_SIZE)
-
-    Example:
-        If we have:
-            - req_pool_indices = [0, 1] (request 0 uses pool 0, request 1 uses pool 1)
-            - page_kernel_lens = [3, 2] (request 0 has 3 tokens, request 1 has 2 tokens)
-            - req_to_token = [[10, 11, 12, -1], [20, 21, -1, -1]] (tokens are the elements
-              in the radix tree; use them as pointers to the token locations in kv_indices_ptr)
-
-        The kernel will output:
-            If PAGE_SIZE = 1:
-                packed:
-                    - kv_indptr (passed in as input arg): [0, 3, 5]
-                    - kv_indices = [10, 11, 12, 20, 21]
-                padded (max_pages is 10 tokens per req):
-                    - kv_indptr (passed in as input arg): [0, 10, 20]
-                    - kv_indices = [10, 11, 12, -1, -1, -1, -1, -1, -1, -1,
-                                    20, 21, -1, -1, -1, -1, -1, -1, -1, -1]
-            If PAGE_SIZE = 2:
-                packed:
-                    - kv_indptr (passed in as input arg): [0, 3, 4]
-                    - kv_indices = [5, 6, 10]
-                padded (max_pages is 4):
-                    - kv_indptr (passed in as input arg): [0, 4, 8, ..] (note that 4 is the max_pages)
-                    - kv_indices = [5, 6, -1, -1,
-                                    10, -1, -1, -1]
-
-        This allows attention kernels to directly access the correct KV cache
-        entries for each request's tokens.
-    """
     BLOCK_SIZE: tl.constexpr = 512
-    NUM_PAGES_PER_BLOCK: tl.constexpr = BLOCK_SIZE // PAGE_SIZE
     pid = tl.program_id(axis=0)
+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)

@@ -102,27 +31,19 @@ def create_flashinfer_kv_indices_triton(
     kv_end = kv_start
     kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

-    kv_range = kv_end - kv_start
-    num_pages = tl.cdiv(kv_range, PAGE_SIZE)
-    num_loops = tl.cdiv(kv_range, BLOCK_SIZE)
-    req_to_token_block_start = (
-        req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
-    )
-    for i in range(num_loops):
-        token_offsets_in_block = (
-            tl.arange(0, NUM_PAGES_PER_BLOCK).to(tl.int64) + i * NUM_PAGES_PER_BLOCK
-        ) * PAGE_SIZE
-        page_offsets_in_block = token_offsets_in_block // PAGE_SIZE
-        valid_tokens = token_offsets_in_block < kv_range
-        valid_pages = page_offsets_in_block < num_pages
-        token_numbers = tl.load(
-            req_to_token_block_start + token_offsets_in_block, mask=valid_tokens
-        )
-        tl.store(
-            kv_indices_ptr + kv_indices_offset + page_offsets_in_block,
-            token_numbers // PAGE_SIZE,  # write the page numbers to kv_indices_ptr
-            mask=valid_pages,
-        )
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


 @triton.jit
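For reference, the restored kernel simply gathers, for each request, the token slots recorded in req_to_token into a flat kv_indices array at the offset given by kv_indptr; the deleted #8593 variant instead wrote deduplicated page numbers. A rough pure-PyTorch equivalent of the restored behaviour, reusing the example from the deleted docstring (this helper is illustrative only and is not part of the repository; kv_start_idx is assumed to be None):

import torch

def create_kv_indices_reference(req_to_token, req_pool_indices, seq_lens, kv_indptr):
    # req_to_token: [max_batch, max_context_len] token slots per request pool
    # req_pool_indices, seq_lens: [batch_size]
    # kv_indptr: [batch_size + 1] cumulative token counts per request
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(len(req_pool_indices)):
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        kv_indices[start:end] = req_to_token[int(req_pool_indices[i]), : int(seq_lens[i])]
    return kv_indices

# Example mirroring the deleted docstring: two requests with 3 and 2 tokens.
req_to_token = torch.tensor([[10, 11, 12, -1], [20, 21, -1, -1]], dtype=torch.int32)
req_pool_indices = torch.tensor([0, 1])
seq_lens = torch.tensor([3, 2])
kv_indptr = torch.tensor([0, 3, 5])
print(create_kv_indices_reference(req_to_token, req_pool_indices, seq_lens, kv_indptr))
# tensor([10, 11, 12, 20, 21], dtype=torch.int32)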
python/sglang/srt/server_args.py (view file @ ebd9dbe7)

@@ -639,10 +639,6 @@ class ServerArgs:
             logger.warning(
                 "DeepSeek MTP does not require setting speculative_draft_model_path."
             )
-            if self.page_size != 1 and self.attention_backend == "flashinfer":
-                raise ValueError(
-                    "Speculative decoding with page_size != 1 is not supported. Please set page_size to 1."
-                )

         # Auto choose parameters
         if self.speculative_num_steps is None:
test/srt/test_create_kvindices.py (view file @ ebd9dbe7)

@@ -4,10 +4,7 @@ import unittest
 import numpy as np
 import torch
-from sglang.srt.layers.attention.utils import (
-    create_flashinfer_kv_indices_triton,
-    create_flashmla_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.test.test_utils import CustomTestCase

@@ -18,14 +15,10 @@ class TestCreateKvIndices(CustomTestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")

-    def _run_test(self, batch, max_batch, max_context_len, page_size):
-        np.random.seed(9)
-        PAGE_SIZE = page_size
+    def _run_test(self, batch, max_batch, max_context_len):
         req_to_token = torch.arange(
             max_batch * max_context_len, dtype=torch.int32, device="cuda"
         ).reshape((max_batch, max_context_len))
-        # the block table
         req_pool_indices = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_batch), size=batch, replace=False)

@@ -33,84 +26,49 @@ class TestCreateKvIndices(CustomTestCase):
             dtype=torch.int32,
             device="cuda",
         )
-        seq_lens = torch.tensor(
+        paged_kernel_lens = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_context_len), size=batch, replace=False)
             ),
             dtype=torch.int32,
             device="cuda",
         )

-        num_pages_per_req = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE
         kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(num_pages_per_req, dim=0)
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

         # ref
-        kv_indices_ref = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        seq_lens_cpu = seq_lens.cpu().numpy()
-        for i in range(batch):
-            kv_indptr_req = kv_indptr[i]
-            num_toks_seq = seq_lens_cpu[i]
-            curr_req_pool = req_pool_indices_cpu[i]
-            curr_num_pages = num_pages_per_req[i]
-            curr_token_ids = req_to_token[curr_req_pool]
-            curr_pages = (curr_token_ids[:num_toks_seq] // PAGE_SIZE).unique()
-            assert (
-                len(curr_pages) == curr_num_pages
-            ), f"req {i} has #{curr_num_pages} pages, but got {len(curr_pages)} pages"
-            kv_indices_ref[kv_indptr_req : kv_indptr_req + curr_num_pages] = curr_pages
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices_ref = torch.cat(
+            [
+                req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
+                for i in range(batch)
+            ],
+            dim=0,
+        ).contiguous()

         # triton
         kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(batch,)](
             req_to_token,
             req_pool_indices,
-            seq_lens,
+            paged_kernel_lens,
             kv_indptr,
             None,
             kv_indices_triton,
             req_to_token.size(1),
-            PAGE_SIZE,
         )

-        max_pages = max_context_len // PAGE_SIZE
-        kv_indices_flashmla = torch.empty(
-            batch, max_pages, dtype=torch.int32, device="cuda"
-        )
-        create_flashmla_kv_indices_triton[(batch,)](
-            req_to_token,
-            req_pool_indices,
-            seq_lens,
-            None,
-            kv_indices_flashmla,
-            req_to_token.size(1),
-            max_pages,
-            PAGE_SIZE,
-        )

         # Check
         self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))

     def test_create_kvindices(self):
-        BATCH = [4, 37, 512, 1786]
+        BATCH = [1, 37, 1786]
         MAX_BATCH = 4096
         MAX_CONTEXT_LEN = 4096
-        PAGE_SIZE = [1, 2, 16, 64]
-
-        # Test for small batch size
-        for page_size in PAGE_SIZE[:1]:
-            print(
-                f"Running test for page size: {page_size} and batch size: {BATCH[0]}"
-            )
-            self._run_test(BATCH[0], MAX_BATCH, MAX_CONTEXT_LEN, page_size)
-
-        # Test for larger batch size
-        for batch in BATCH[1:]:
-            for page_size in PAGE_SIZE:
-                print(
-                    f"Running test for batch size: {batch} and page size: {page_size}"
-                )
-                self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN, page_size)
+        # for debug
+        # BATCH = [4]
+        # MAX_BATCH = 4
+        # MAX_CONTEXT_LEN = 10
+        for batch in BATCH:
+            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)


 if __name__ == "__main__":
test/srt/test_mla_flashinfer.py (view file @ ebd9dbe7)

@@ -120,49 +120,5 @@ class TestFlashinferMLAMTP(CustomTestCase):
         self.assertGreater(avg_spec_accept_length, 2.5)


-class TestFlashinferMLAPageSize16(CustomTestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        other_args = ["--trust-remote-code"]
-        if torch.cuda.is_available() and torch.version.cuda:
-            other_args.extend(
-                [
-                    "--cuda-graph-max-bs",
-                    "4",
-                    "--attention-backend",
-                    "flashinfer",
-                    "--page-size",
-                    "16",
-                ]
-            )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=other_args,
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_gsm8k(self):
-        args = SimpleNamespace(
-            num_shots=5,
-            data_path=None,
-            num_questions=200,
-            max_new_tokens=512,
-            parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.base_url.split(":")[-1]),
-        )
-        metrics = run_eval_few_shot_gsm8k(args)
-        print(metrics)
-        self.assertGreater(metrics["accuracy"], 0.615)
-
-
 if __name__ == "__main__":
     unittest.main()
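The deleted TestFlashinferMLAPageSize16 class was the end-to-end coverage for the FlashInfer MLA backend with --page-size 16; after this revert that configuration is no longer exercised. A condensed sketch of the launch it performed, using the same helpers and flags that appear in the removed code (the import path of the test helpers is an assumption based on sglang's other tests and is not shown in this diff):

# Sketch only: mirrors the configuration covered by the removed test class.
# The import location of these helpers is assumed, not taken from this diff.
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--trust-remote-code",
        "--cuda-graph-max-bs", "4",
        "--attention-backend", "flashinfer",
        "--page-size", "16",
    ],
)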