sglang · Commits · 53f7874a

Commit 53f7874a (unverified), authored Aug 09, 2025 by valarLip, committed by GitHub Aug 08, 2025
Parent: 61a46804

refine aiter_backend for mtp (#7279)

Co-authored-by: HAI <hixiao@gmail.com>
Showing 3 changed files with 387 additions and 107 deletions (+387 −107):

    python/sglang/srt/layers/attention/aiter_backend.py   +370 −107
    python/sglang/srt/managers/schedule_batch.py           +1 −0
    python/sglang/srt/speculative/eagle_worker.py          +16 −0
python/sglang/srt/layers/attention/aiter_backend.py (view file @ 53f7874a)
@@ -32,7 +32,7 @@ try:
         mha_batch_prefill_func,
         paged_attention_ragged,
     )
-    from aiter.mla import mla_decode_fwd
+    from aiter.mla import mla_decode_fwd, mla_prefill_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
@@ -52,10 +52,8 @@ class ForwardMetadata:
     kv_indices: torch.Tensor
     qo_indptr: torch.Tensor
     kv_last_page_len: torch.Tensor
-    max_extend_len: int
-    max_prefix_extend_len: int
     max_q_len: int
-    max_kv_len: int
+    max_kv_len: Optional[int]


 global_workspace_buffer = None
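Note (editor's illustration, not part of the commit): ForwardMetadata now carries max_q_len / max_kv_len instead of the old max_extend_len / max_prefix_extend_len. The sketch below shows, under assumed variable names, how the indptr arrays and these per-batch maxima are typically derived from per-request lengths.

# Illustrative only; variable names are assumptions, not code from this commit.
import torch

def build_metadata_inputs(extend_lens: torch.Tensor, kv_lens: torch.Tensor):
    bs = extend_lens.numel()
    # qo_indptr[i] is the start offset of request i's query tokens in the
    # flattened query tensor; kv_indptr plays the same role for KV entries.
    qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)
    # The refactored ForwardMetadata keeps only the maxima the kernels need.
    max_q_len = int(extend_lens.max().item())
    max_kv_len = int(kv_lens.max().item())
    return qo_indptr, kv_indptr, max_q_len, max_kv_len

# Example: two requests extending 3 and 5 tokens over 10 and 12 cached tokens.
qo, kv, mq, mkv = build_metadata_inputs(torch.tensor([3, 5]), torch.tensor([10, 12]))
# qo = [0, 3, 8], kv = [0, 10, 22], mq = 5, mkv = 22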
@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()

+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
         spec_info = forward_batch.spec_info
         qo_indptr = None
         kv_last_page_len = None
-        max_extend_len = None
+        max_q_len = None

         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.zeros(
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
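Note (editor's illustration, not part of the commit): create_flashinfer_kv_indices_triton is a Triton kernel; the pure-PyTorch sketch below only mimics the index gathering it performs, with made-up inputs.

import torch

def gather_kv_indices(req_to_token, req_pool_indices, seq_lens, kv_indptr):
    # req_to_token: (num_request_slots, max_context_len) table of KV-cache slots.
    # kv_indices is the flat concatenation of each request's slots, delimited
    # by kv_indptr.
    kv_indices = torch.empty(int(seq_lens.sum()), dtype=torch.int32)
    for i, (req, n) in enumerate(zip(req_pool_indices.tolist(), seq_lens.tolist())):
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[req, :n]
    return kv_indices

req_to_token = torch.arange(4 * 8, dtype=torch.int32).reshape(4, 8)
print(gather_kv_indices(req_to_token, torch.tensor([2, 0]),
                        torch.tensor([3, 5]), torch.tensor([0, 3, 8])))
# tensor([16, 17, 18,  0,  1,  2,  3,  4], dtype=torch.int32)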
@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr = self.qo_indptr_[: bs + 1]
                 qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                 kv_last_page_len = self.kv_last_page_len[:bs]
-                max_extend_len = 1
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
         elif forward_batch.forward_mode.is_draft_extend():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
-                )
+                kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                    spec_info.generate_attn_arg_prefill(
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.seq_lens_sum,
+                        self.req_to_token,
+                    )
+                )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
                     None,
                     None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
         elif forward_batch.forward_mode.is_target_verify():
             if self.use_mla:
-                prefix_lens = forward_batch.extend_prefix_lens
-                self.mla_indices_updater_prefill.update(
-                    forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
-                    forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
-                    spec_info=None,
-                )
+                draft_num = spec_info.draft_token_num
+                kv_lens = forward_batch.seq_lens + draft_num
+                kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
+                device = forward_batch.seq_lens.device
+                qo_indptr = torch.arange(
+                    0,
+                    (1 + bs) * draft_num,
+                    step=draft_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    kv_lens_sum,
+                    dtype=torch.int32,
+                    device=device,
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    kv_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
                 self.forward_metadata = ForwardMetadata(
-                    self.mla_indices_updater_prefill.kv_indptr,
-                    self.mla_indices_updater_prefill.kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    # self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.kv_last_page_len[:bs],
+                    draft_num,
                     None,
                 )
             else:
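Note (editor's illustration, not part of the commit): in the new target-verify path every request contributes exactly draft_token_num query tokens, so qo_indptr is a uniform arange while kv_indptr is a cumsum over seq_lens + draft_token_num. A made-up example:

import torch

bs, draft_num = 2, 4
seq_lens = torch.tensor([10, 7])                  # cached tokens per request
qo_indptr = torch.arange(0, (1 + bs) * draft_num, step=draft_num, dtype=torch.int32)
# -> tensor([0, 4, 8]): every request contributes exactly draft_num query tokens.
kv_lens = seq_lens + draft_num                    # draft tokens already sit in the KV cache
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)
# -> tensor([0, 14, 25]): each request attends over its prefix plus the draft tokens.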
@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
                 extend_no_prefix = False
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    prefix_lens,
-                    prefix_lens.sum().item(),
+                    forward_batch.extend_prefix_lens,
+                    sum(forward_batch.extend_prefix_lens_cpu),
                     forward_batch.extend_seq_lens,
-                    encoder_lens=forward_batch.encoder_lens,
+                    max(forward_batch.extend_seq_lens_cpu),
+                    forward_batch.seq_lens_cpu.max().item(),
                     spec_info=None,
                 )
                 self.mla_indices_updater_prefill.kv_indptr += (
                     self.mla_indices_updater_prefill.qo_indptr
                 )
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
                     self.mla_indices_updater_prefill.kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
-                    self.mla_indices_updater_prefill.kv_last_page_len,
-                    self.mla_indices_updater_prefill.max_extend_len,
-                    self.mla_indices_updater_prefill.max_prefix_extend_len,
-                    None,
-                    None,
+                    self.kv_last_page_len[:bs],
+                    self.mla_indices_updater_prefill.max_q_len,
+                    self.mla_indices_updater_prefill.max_kv_len,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
                     self.indices_updater_prefill.kv_indices,
                     None,
                     None,
-                    None,
-                    None,
                     self.indices_updater_prefill.max_q_len,
                     self.indices_updater_prefill.max_kv_len,
                 )
@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
             kv_last_page_len = None
-            max_extend_len = None
+            max_q_len = None

             if spec_info is None:
                 kv_indptr = self.kv_indptr
@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
                 qo_indptr[1 : bs + 1] = torch.cumsum(
                     self.cuda_graph_kv_last_page_len[:bs], dim=0
                 )
-                max_extend_len = 1
                 kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = 1

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
-                max_extend_len = self.num_draft_tokens
-                kv_last_page_len = None
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+                max_q_len = self.num_draft_tokens

             self.forward_metadata = ForwardMetadata(
                 kv_indptr,
                 kv_indices,
                 qo_indptr,
                 kv_last_page_len,
-                max_extend_len,
-                None,
-                None,
+                max_q_len,
                 None,
             )
         else:
@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
                 self.indices_updater_prefill.kv_indices,
                 None,
                 None,
-                None,
-                None,
                 self.indices_updater_prefill.max_q_len,
                 self.indices_updater_prefill.max_kv_len,
             )
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+            max_q_len = num_tokens_per_bs
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_q_len,
+                None,
+            )
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
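Note (editor's illustration, not part of the commit): the CUDA-graph capture path writes into preallocated buffers and slices views of them, so tensor addresses stay stable between capture and replay. A minimal sketch with made-up sizes:

import torch

MAX_BS = 8
qo_indptr_buf = torch.zeros(MAX_BS + 1, dtype=torch.int32)

def capture_step(bs: int, num_tokens_per_bs: int):
    # Take a view of the shared buffer and overwrite it in place; no new allocation.
    qo_indptr = qo_indptr_buf[: bs + 1]
    qo_indptr[: bs + 1] = torch.arange(
        0, bs * num_tokens_per_bs + 1, step=num_tokens_per_bs, dtype=torch.int32
    )
    return qo_indptr

print(capture_step(3, 4))   # tensor([ 0,  4,  8, 12], dtype=torch.int32)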
@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
         elif forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                req_pool_indices[:bs],
-                seq_lens[:bs],
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-                spec_info=spec_info,
-            )
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_lens = seq_lens + self.num_draft_tokens
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                kv_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
         else:
             raise ValueError("Invalid forward mode")
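Note (editor's illustration, not part of the commit): in draft-extend replay the per-request query counts are the accepted lengths from the previous verify step, so qo_indptr is a cumsum over a ragged accept_length vector rather than a uniform arange. A made-up example:

import torch

bs = 3
accept_length = torch.tensor([2, 1, 3], dtype=torch.int32)  # tokens accepted per request
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(accept_length, dim=0)
# -> tensor([0, 2, 3, 6]): request i's queries occupy qo_indptr[i]:qo_indptr[i+1]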
@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
         )

         if self.use_mla:
-            max_extend_len = self.forward_metadata.max_extend_len
-            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            max_q_len = self.forward_metadata.max_q_len
+            max_kv_len = self.forward_metadata.max_kv_len
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
             kv_last_page_lens = self.forward_metadata.kv_last_page_len
             qo_indptr = self.forward_metadata.qo_indptr
             K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     qo_indptr,
-                    max_extend_len,
-                    max_extend_len,
+                    max_q_len,
+                    max_q_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
                     v,
                     qo_indptr,
                     kv_indptr,
-                    max_extend_len,
-                    max_prefix_extend_len,
+                    max_q_len,
+                    max_kv_len,
                     softmax_scale=layer.scaling,
                     causal=True,
                 )
                 return o
+            elif forward_batch.forward_mode.is_target_verify():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                mla_decode_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+            elif forward_batch.forward_mode.is_draft_extend():
+                o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
+                causal = True
+                sliding_window_size = -1
+                kv_indptr = self.forward_metadata.kv_indptr
+                kv_indices = self.forward_metadata.kv_indices
+                mla_prefill_fwd(
+                    q,
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o,
+                    self.forward_metadata.qo_indptr,
+                    self.forward_metadata.kv_indptr,
+                    self.forward_metadata.kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
+                return o
+                # self.extend_attention_fwd(
+                #     q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                #     k.contiguous(),
+                #     v.contiguous(),
+                #     o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                #     forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                #     forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                #     self.forward_metadata.qo_indptr,
+                #     kv_indptr,
+                #     kv_indices,
+                #     None,
+                #     causal,
+                #     None,
+                #     self.forward_metadata.max_q_len,
+                #     layer.scaling,
+                #     layer.logit_cap,
+                #     sliding_window_size,
+                # )
+                # return o
+            else:
+                raise ValueError(
+                    f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
+                )
         else:
             k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
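Note (editor's illustration, not part of the commit): the mla_decode_fwd / mla_prefill_fwd calls above reinterpret the flat KV cache as a paged layout with page size 1 via .view(-1, 1, 1, qk_head_dim); no data is copied. A shape-only sketch with made-up sizes:

import torch

num_cache_tokens, qk_head_dim = 1024, 576
k_buffer = torch.zeros(num_cache_tokens, qk_head_dim)
paged = k_buffer.view(-1, 1, 1, qk_head_dim)    # (pages, page_size=1, kv_heads=1, head_dim)
assert paged.data_ptr() == k_buffer.data_ptr()  # a view of the same storage, not a copy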
@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
                     self.forward_metadata.kv_indptr,
                     self.forward_metadata.kv_indices,
                     self.forward_metadata.kv_last_page_len,
-                    self.forward_metadata.max_extend_len,
+                    self.forward_metadata.max_q_len,
                     layer.scaling,
                     layer.logit_cap,
                 )
@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indices = None
         self.qo_indptr = None
         self.kv_last_page_len = None
-        self.max_extend_len = 0
-        self.max_prefix_extend_len = 0
+        self.max_q_len = 0
+        self.max_kv_len = 0

     def update(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
     def update_single_wrapper(
         self,
         req_pool_indices: torch.Tensor,
-        prefix_lens: torch.Tensor,
-        prefix_lens_sum: int,
+        kv_lens: torch.Tensor,
+        kv_lens_sum: int,
         extend_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
+        max_q_len: int,
+        max_kv_len: int,
         spec_info: Optional[SpecInfo],
     ):
-        paged_kernel_lens = prefix_lens
-        paged_kernel_lens_sum = prefix_lens_sum
-
         bs = len(req_pool_indices)
         kv_indptr = self.attn_backend.kv_indptr

         if spec_info is None:
             # Normal extend
-            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                paged_kernel_lens_sum,
+                kv_lens_sum,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices,
-                paged_kernel_lens,
+                kv_lens,
                 kv_indptr,
                 None,
                 kv_indices,
@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
             qo_indptr = self.attn_backend.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
-            max_extend_len = torch.max(extend_lens).item()
-            max_prefix_extend_len = torch.max(
-                extend_lens + paged_kernel_lens
-            ).item()
             kv_indptr += qo_indptr
         else:
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
-                    paged_kernel_lens,
-                    paged_kernel_lens_sum,
+                    kv_lens,
+                    kv_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
         self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
         self.qo_indptr = qo_indptr
-        self.max_extend_len = max_extend_len
-        self.max_prefix_extend_len = max_prefix_extend_len
+        self.max_q_len = max_q_len
+        self.max_kv_len = max_kv_len
+
+
+class AiterMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                AiterAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
+        assert self.page_size == 1, "Page size must be 1"
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+            self.page_size,
+        )
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.empty(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
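Note (editor's illustration, not part of the commit): the slice bound seq_lens_sum * self.topk + bs * (i + 1) in common_template reflects that after draft step i each of the bs = topk * num_seqs candidate sequences holds its original context plus i + 1 freshly drafted tokens. A made-up example:

topk, num_seqs = 2, 3
seq_lens = [5, 8, 4]
seq_lens_sum = sum(seq_lens)            # 17
bs = topk * num_seqs                    # 6 candidate sequences
for i in range(3):                      # speculative_num_steps = 3
    print(i, seq_lens_sum * topk + bs * (i + 1))   # 40, 46, 52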
python/sglang/srt/managers/schedule_batch.py (view file @ 53f7874a)
@@ -1722,6 +1722,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
             or attention_backend_str == "trtllm_mha"
+            or attention_backend_str == "aiter"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
python/sglang/srt/speculative/eagle_worker.py (view file @ 53f7874a)
@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import (
+                AiterAttnBackend,
+                AiterMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = AiterMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = AiterAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
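Note (editor's illustration, not part of the commit): a simplified sketch of the dispatch pattern the EAGLE worker follows when wiring up the aiter draft backends; the surrounding class and the other backend branches are omitted, and the free function here is hypothetical.

def build_draft_backends(server_args, draft_model_runner, topk, speculative_num_steps):
    if server_args.attention_backend == "aiter":
        from sglang.srt.layers.attention.aiter_backend import (
            AiterAttnBackend,
            AiterMultiStepDraftBackend,
        )

        # One backend instance per draft step for multi-step decoding,
        # plus a plain backend for the draft-extend pass.
        draft_attn_backend = AiterMultiStepDraftBackend(
            draft_model_runner, topk, speculative_num_steps
        )
        draft_extend_attn_backend = AiterAttnBackend(
            draft_model_runner, skip_prefill=False
        )
        return draft_attn_backend, draft_extend_attn_backend
    raise NotImplementedError(server_args.attention_backend)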