change / sglang / Commits / 53f7874a

Unverified commit 53f7874a, authored Aug 09, 2025 by valarLip, committed by GitHub on Aug 08, 2025

refine aiter_backend for mtp (#7279)

Co-authored-by: HAI <hixiao@gmail.com>

parent 61a46804

Showing 3 changed files with 387 additions and 107 deletions (+387, -107)
python/sglang/srt/layers/attention/aiter_backend.py    +370  -107
python/sglang/srt/managers/schedule_batch.py              +1    -0
python/sglang/srt/speculative/eagle_worker.py             +16   -0
python/sglang/srt/layers/attention/aiter_backend.py
@@ -32,7 +32,7 @@ try:
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]


 global_workspace_buffer = None
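Taken together, this hunk drops the separate max_extend_len / max_prefix_extend_len fields and keeps a single max_q_len / max_kv_len pair. For reference, a minimal sketch of the metadata container as it reads after the change, reconstructed from the fields shown in this diff (the @dataclass decorator is assumed from the surrounding file):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    kv_last_page_len: torch.Tensor
    max_q_len: int
    max_kv_len: Optional[int]

max_kv_len becomes Optional because the decode and target-verify paths in the later hunks pass None for it.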
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
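The constructor now imports the Triton extend-attention kernel lazily and wraps it with torch.compiler.disable so the launcher is not traced. A minimal sketch of that wrapping pattern on a stand-in function (the stand-in itself is hypothetical, for illustration only):

import torch

def _extend_attention_stub(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the Triton extend_attention_fwd wrapper.
    return q * 2.0

# Same pattern as the constructor above: keep the kernel launcher out of
# torch.compile/dynamo tracing while leaving eager execution unchanged.
_extend_attention_stub = torch.compiler.disable(_extend_attention_stub)

print(_extend_attention_stub(torch.ones(2)))  # tensor([2., 2.])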
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-        max_extend_len = None
+        max_q_len = None

         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.zeros(
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
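The decode path builds kv_indptr by a cumulative sum over the per-request sequence lengths and sizes the flat kv_indices buffer from their sum (now with torch.empty instead of torch.zeros, since every slot is overwritten by the Triton kernel). A tiny standalone sketch of that indptr layout, with made-up lengths:

import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)   # hypothetical batch of 3 requests
bs = seq_lens.numel()

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)

print(kv_indptr)          # [0, 3, 8, 10] -> request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]
kv_indices = torch.empty(int(seq_lens.sum()), dtype=torch.int32)
print(kv_indices.shape)   # torch.Size([10]); filled per request by create_flashinfer_kv_indices_triton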
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
             qo_indptr = self.qo_indptr_[: bs + 1]
             qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
             kv_last_page_len = self.kv_last_page_len[:bs]
-            max_extend_len = 1
+            max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
-                )
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
+                )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
-                )
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
+                    None,
                 )
             else:
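In the new target-verify path every request contributes exactly draft_token_num query positions, so qo_indptr is a fixed-stride arange, while the KV side covers the existing prefix plus the draft tokens. A small standalone sketch of the two index pointers, with hypothetical lengths:

import torch

seq_lens = torch.tensor([7, 4], dtype=torch.int32)   # hypothetical prefix lengths for 2 requests
bs, draft_num = seq_lens.numel(), 3                  # e.g. speculative_num_draft_tokens = 3

qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)
kv_lens = seq_lens + draft_num
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)

print(qo_indptr)   # [0, 3, 6]   -> 3 query tokens per request
print(kv_indptr)   # [0, 10, 17] -> each request attends to its prefix plus the draft tokens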
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
+                self.mla_indices_updater_prefill.kv_indptr += (
+                    self.mla_indices_updater_prefill.qo_indptr
+                )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-            max_extend_len = None
+            max_q_len = None
             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
-                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
-                max_extend_len = self.num_draft_tokens
+                max_q_len = self.num_draft_tokens
+                kv_last_page_len = None
                 self.forward_metadata = ForwardMetadata(
                     kv_indptr,
                     kv_indices,
                     qo_indptr,
                     kv_last_page_len,
-                    max_extend_len,
-                    None,
-                    None,
+                    max_q_len,
                     None,
                 )
             else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
             kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
         elif forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                req_pool_indices[:bs],
-                seq_lens[:bs],
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-                spec_info=spec_info,
-            )
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
         else:
             raise ValueError("Invalid forward mode")
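For the draft-extend mode, CUDA-graph capture has to assume the worst case of speculative_num_steps + 1 tokens per request, while replay later rebuilds qo_indptr from the actual spec_info.accept_length. A small sketch contrasting the two layouts (lengths are made up):

import torch

bs, speculative_num_steps = 2, 3
num_tokens_per_bs = speculative_num_steps + 1        # worst case: every draft accepted plus the bonus token

# Capture time: fixed stride, as in init_forward_metadata_capture_cuda_graph above.
qo_indptr_capture = torch.arange(
    0, bs * num_tokens_per_bs + 1, step=num_tokens_per_bs, dtype=torch.int32
)

# Replay time: the real accepted lengths, as in init_forward_metadata_replay_cuda_graph above.
accept_lens = torch.tensor([2, 4], dtype=torch.int32)  # hypothetical accept_length per request
qo_indptr_replay = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr_replay[1:] = torch.cumsum(accept_lens, dim=0)

print(qo_indptr_capture)  # [0, 4, 8]
print(qo_indptr_replay)   # [0, 2, 6]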
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
         )
         if self.use_mla:
-            max_extend_len = self.forward_metadata.max_extend_len
-            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
-            kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     qo_indptr,
-                    max_extend_len,
-                    max_extend_len,
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     kv_indptr,
-                    max_extend_len,
-                    max_prefix_extend_len,
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )
         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
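Both new speculative branches allocate the output as (num_tokens, tp_q_head_num, v_head_dim) and hand the paged key buffer to the aiter MLA kernels as a (-1, 1, 1, qk_head_dim) view, restoring the (-1, 1, qk_head_dim) view afterwards. A shape-only sketch with made-up dimensions (no aiter install required; head sizes are illustrative, not taken from any particular model config):

import torch

num_tokens, tp_q_head_num = 6, 16
qk_head_dim, v_head_dim = 576, 512      # illustrative MLA head dims only
num_kv_slots = 32

q = torch.randn(num_tokens, tp_q_head_num, qk_head_dim)
k_buffer = torch.randn(num_kv_slots, 1, qk_head_dim)       # paged KV buffer, one latent head

o = q.new_empty((q.shape[0], tp_q_head_num, v_head_dim))   # output buffer, as in the branches above
k_for_kernel = k_buffer.view(-1, 1, 1, qk_head_dim)        # layout handed to mla_decode_fwd / mla_prefill_fwd
k_buffer = k_buffer.view(-1, 1, qk_head_dim)               # view restored after the call, as above

print(o.shape, k_for_kernel.shape, k_buffer.shape)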
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
                 self.forward_metadata.kv_indptr,
                 self.forward_metadata.kv_indices,
                 self.forward_metadata.kv_last_page_len,
-                self.forward_metadata.max_extend_len,
+                self.forward_metadata.max_q_len,
                 layer.scaling,
                 layer.logit_cap,
             )
@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.max_extend_len = 0
-        self.max_prefix_extend_len = 0
+        self.max_q_len = 0
+        self.max_kv_len = 0

     def update(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
         bs = len(req_pool_indices)
         kv_indptr = self.attn_backend.kv_indptr

         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum,
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-                paged_kernel_lens,
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
-            kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-                    paged_kernel_lens,
-                    paged_kernel_lens_sum,
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.max_extend_len = max_extend_len
-        self.max_prefix_extend_len = max_prefix_extend_len
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
python/sglang/srt/managers/schedule_batch.py
@@ -1722,6 +1722,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
             or attention_backend_str == "trtllm_mha"
+            or attention_backend_str == "aiter"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
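The new aiter metadata paths read forward_batch.seq_lens_cpu (see the max_kv_len arguments in the first file), so ScheduleBatch now keeps a host copy of seq_lens whenever the attention backend string is "aiter". A minimal sketch of why the host copy is enough for that metadata (lengths are made up):

import torch

seq_lens = torch.tensor([7, 4], dtype=torch.int32)   # lives on the GPU in a real batch
seq_lens_cpu = seq_lens.to("cpu")                     # the copy ScheduleBatch now maintains for "aiter"

# Scalars such as max_kv_len can then be computed without a device sync:
max_kv_len = seq_lens_cpu.max().item()
print(max_kv_len)   # 7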
python/sglang/srt/speculative/eagle_worker.py
@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import (
+                AiterAttnBackend,
+                AiterMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = AiterMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = AiterAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,