sglang · Commit 9ec314c6 (Unverified)
Authored Aug 21, 2025 by Qiaolin Yu; committed by GitHub on Aug 21, 2025
Parent: fedfe91c

Support speculative decoding in the trtllm_mha attention backend (#9331)

Co-authored-by: ispobock <ispobaoke@gmail.com>
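The diff below wires EAGLE speculative decoding into the TRT-LLM MHA attention backend and replaces the old "not supported" guard in ServerArgs with a topk check. A minimal configuration sketch, assuming the usual ServerArgs fields; the model paths and numeric values are illustrative placeholders, not part of this commit:

    # Sketch only: paths and values are assumptions for illustration.
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="path/to/target-model",                   # placeholder
        speculative_draft_model_path="path/to/draft-model",  # placeholder
        attention_backend="trtllm_mha",
        speculative_algorithm="EAGLE",
        speculative_num_steps=3,
        speculative_eagle_topk=1,        # trtllm_mha only supports topk = 1
        speculative_num_draft_tokens=4,  # checked against speculative_num_steps + 1
    )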
Showing 3 changed files with 414 additions and 33 deletions:

  python/sglang/srt/layers/attention/trtllm_mha_backend.py  (+388, -28)
  python/sglang/srt/server_args.py  (+10, -5)
  python/sglang/srt/speculative/eagle_worker.py  (+16, -0)
python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional

 import torch

-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
     import flashinfer

+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
-        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )

         config = model_runner.model_config
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}

+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
         # Forward metadata
         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
@@ -97,11 +115,76 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
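The target-verify cu_seqlens_q buffer above is a strided torch.arange; a small standalone check of what it holds, with an arbitrary batch size and draft-token count chosen for illustration:

    import torch

    max_bs, num_draft = 4, 3  # example values only
    cu_seqlens_q = torch.arange(
        0, max_bs * num_draft + 1, step=num_draft, dtype=torch.int32
    )
    print(cu_seqlens_q)  # [0, 3, 6, 9, 12]: each request contributes num_draft query tokens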
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
+        device = seq_lens.device
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][:bs, :]
+                self.decode_cuda_graph_metadata[bs] = metadata
+            else:
+                # Normal Decode
                 # Get sequence information
                 metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+                batch_size = len(seq_lens)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
                 # Precompute maximum sequence length
-            metadata.max_seq_len_k = seq_lens[:bs].max().item()
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
                 # Precompute page table
-            metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    :bs, :
+                ]
                 self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            # Target Verify
+            # Here we only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
+            self.target_verify_metadata[bs] = metadata
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            num_tokens_per_bs = num_tokens // bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
+            self.draft_extend_metadata[bs] = metadata

         self.forward_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
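The draft-decode branch above builds cu_seqlens_k by left-padding a cumulative sum with a zero; a standalone illustration with made-up per-request KV lengths:

    import torch

    cache_seqlens = torch.tensor([5, 2, 7], dtype=torch.int32)  # example KV lengths
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    print(cu_seqlens_k)  # [0, 5, 7, 14]: prefix sums with a leading zero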
@@ -149,9 +321,23 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
+        device = seq_lens.device
         metadata = None

+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                metadata.cache_seqlens_int32.copy_(
+                    seq_lens + self.speculative_step_id + 1
+                )
+            else:
                 # Normal Decode
                 metadata = self.decode_cuda_graph_metadata[bs]
                 max_len = seq_lens_cpu.max().item()
@@ -159,9 +345,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 metadata.max_seq_len_k = max_len
                 metadata.cache_seqlens_int32.copy_(seq_lens)
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
             page_indices = self.req_to_token[
                 req_pool_indices[:, None],
-                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
             ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        elif forward_mode.is_target_verify():
+            # Here we only support topk = 1 for now.
+            metadata = self.target_verify_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)

         self.forward_metadata = metadata
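The replay paths above rebuild the page table by sampling req_to_token at page-stride offsets and dividing by the page size; a toy illustration with made-up slot assignments for a single request:

    import torch

    page_size, max_context_len = 4, 16  # example values only
    # One request whose KV cache occupies contiguous token slots 8..23.
    req_to_token = torch.arange(8, 8 + max_context_len).unsqueeze(0)
    strided_indices = torch.arange(0, max_context_len, page_size)  # [0, 4, 8, 12]
    page_table = req_to_token[:, strided_indices] // page_size
    print(page_table)  # [[2, 3, 4, 5]]: one page id per occupied KV page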
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         device = seqlens_in_batch.device

         if forward_batch.forward_mode.is_decode_or_idle():
+            if forward_batch.spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+        elif forward_batch.forward_mode.is_target_verify():
+            # Only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
+            ).to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]

-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
+    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+        for i in range(speculative_num_steps):
+            self.attn_backends[i] = TRTLLMHAAttnBackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                kv_last_page_len_buf=self.kv_last_page_len,
+                speculative_step_id=i,
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
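A usage sketch for the new multi-step class, mirroring how eagle_worker.py wires it up in the last file of this commit; the ModelRunner instance and the step count here are assumptions for illustration:

    # Hypothetical wiring; `draft_model_runner` stands in for an existing ModelRunner.
    draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
        draft_model_runner, topk=1, speculative_num_steps=3
    )
    # One TRTLLMHAAttnBackend is created per draft step, each tagged with its own
    # speculative_step_id so replay can offset cached sequence lengths by (step_id + 1).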
python/sglang/srt/server_args.py
@@ -500,11 +500,6 @@ class ServerArgs:
             )
             self.page_size = 64

-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mha backend does not support speculative decoding yet."
-                )
-
         if self.attention_backend == "dual_chunk_flash_attn":
             logger.warning(
                 "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
@@ -653,6 +648,16 @@ class ServerArgs:
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)

+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
+            if self.speculative_eagle_topk > 1:
+                raise ValueError(
+                    "trtllm_mha backend only supports topk = 1 for speculative decoding."
+                )
+
         if (
             self.speculative_eagle_topk == 1
             and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
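The new guard only constrains topk; the pre-existing check that follows it (whose body lies outside this hunk) additionally compares speculative_num_draft_tokens against speculative_num_steps + 1 when topk == 1. A plain restatement of the visible conditions, with example numbers:

    # Example values only; mirrors the conditions visible in the hunk above.
    attention_backend = "trtllm_mha"
    speculative_eagle_topk = 1
    speculative_num_steps = 3
    speculative_num_draft_tokens = 4

    if attention_backend == "trtllm_mha" and speculative_eagle_topk > 1:
        raise ValueError(
            "trtllm_mha backend only supports topk = 1 for speculative decoding."
        )
    needs_follow_up = (
        speculative_eagle_topk == 1
        and speculative_num_draft_tokens != speculative_num_steps + 1
    )  # handled by the code that follows this hunk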
python/sglang/srt/speculative/eagle_worker.py
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "trtllm_mla":
             if not global_server_args_dict["use_mla_backend"]:
                 raise ValueError(