sglang / Commits / 9ec314c6

Unverified commit 9ec314c6, authored Aug 21, 2025 by Qiaolin Yu, committed by GitHub on Aug 21, 2025

Support speculative decoding in the trtllm_mha attention backend (#9331)

Co-authored-by: ispobock <ispobaoke@gmail.com>

parent fedfe91c
Showing 3 changed files with 414 additions and 33 deletions

python/sglang/srt/layers/attention/trtllm_mha_backend.py   +388 -28
python/sglang/srt/server_args.py                            +10 -5
python/sglang/srt/speculative/eagle_worker.py               +16 -0
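For context, a launch command that exercises this code path might look like the sketch below. This is not part of the commit: the flag names are inferred from the ServerArgs fields touched in this diff (speculative_algorithm, speculative_eagle_topk, speculative_num_steps, speculative_num_draft_tokens, attention_backend), and the model path and --speculative-draft-model-path value are placeholders.

python3 -m sglang.launch_server \
  --model-path <target-model> \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path <eagle-draft-model> \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --attention-backend trtllm_mha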
python/sglang/srt/layers/attention/trtllm_mha_backend.py @ 9ec314c6
@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional

import torch

-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    import flashinfer

+from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
    ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf)

        config = model_runner.model_config
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        # CUDA graph state
        self.decode_cuda_graph_metadata = {}

+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
        # Forward metadata
        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
@@ -97,11 +115,76 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "page_table": torch.zeros(
                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

        if (
            self.speculative_num_draft_tokens is not None
            and self.speculative_num_draft_tokens > 0
        ):
            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=self.device
            )
            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            )
            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
                max_bs,
                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            )
            self.target_verify_metadata = {
                "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
                "cu_seqlens_q": torch.arange(
                    0,
                    max_bs * self.speculative_num_draft_tokens + 1,
                    step=self.speculative_num_draft_tokens,
                    dtype=torch.int32,
                    device=self.device,
                ),
                "cu_seqlens_k": torch.zeros(max_bs + 1, dtype=torch.int32, device=self.device),
                "page_table": torch.zeros(
                    max_bs,
                    max_num_pages,
                    dtype=torch.int32,
                    device=self.device,
                ),
                "strided_indices": torch.arange(
                    0, self.max_context_len, self.page_size, device=self.device
                ),
            }
            self.draft_extend_metadata = {
                "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
                "cu_seqlens_q": torch.zeros(
                    max_bs + 1,
                    dtype=torch.int32,
                    device=self.device,
                ),
                "cu_seqlens_k": torch.zeros(max_bs + 1, dtype=torch.int32, device=self.device),
                "page_table": torch.zeros(
                    max_bs,
                    max_num_pages,
                    dtype=torch.int32,
                    device=self.device,
                ),
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    ):
        """Initialize metadata for CUDA graph capture."""
        metadata = TRTLLMMHAMetadata()
        device = seq_lens.device

        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
                # Here we only support topk = 1 for now.
                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                    "cache_seqlens"
                ][:bs]
                metadata.max_seq_len_k = seq_lens.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                    : bs + 1
                ]
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
                    (1, 0),
                )
                metadata.page_table = self.decode_cuda_graph_metadata[
                    "page_table_draft_decode"
                ][:bs, :]
                self.decode_cuda_graph_metadata[bs] = metadata
            else:
                # Normal Decode
                # Get sequence information
                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
                batch_size = len(seq_lens)
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
                # Precompute maximum sequence length
-                metadata.max_seq_len_k = seq_lens[:bs].max().item()
+                metadata.max_seq_len_k = seq_lens.max().item()
                # Precompute cumulative sequence lengths
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                # Precompute page table
                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
                self.decode_cuda_graph_metadata[bs] = metadata
        elif forward_mode.is_target_verify():
            # Target Verify
            # Here we only support topk = 1 for now.
            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][:bs]
            metadata.cache_seqlens_int32.copy_(
                (seq_lens + self.speculative_num_draft_tokens)
            )
            metadata.cu_seqlens_q = torch.arange(
                0,
                bs * self.speculative_num_draft_tokens + 1,
                self.speculative_num_draft_tokens,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
            metadata.max_seq_len_q = self.speculative_num_draft_tokens
            metadata.max_seq_len_k = (
                seq_lens.max().item() + self.speculative_num_draft_tokens
            )
            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
            self.target_verify_metadata[bs] = metadata
        elif forward_mode.is_draft_extend():
            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][:bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
            num_tokens_per_bs = num_tokens // bs
            metadata.cu_seqlens_q = torch.arange(
                0,
                bs * num_tokens_per_bs + 1,
                num_tokens_per_bs,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][: (bs + 1)]
            num_tokens_per_bs = num_tokens // bs
            metadata.max_seq_len_q = num_tokens_per_bs
            metadata.max_seq_len_k = seq_lens.max().item()
            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
            self.draft_extend_metadata[bs] = metadata

        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
@@ -149,9 +321,23 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]
        device = seq_lens.device
        metadata = None

        if forward_mode.is_decode_or_idle():
            if spec_info is not None:
                # Draft Decode
                # Here we only support topk = 1 for now.
                metadata = self.decode_cuda_graph_metadata[bs]
                max_len = seq_lens_cpu.max().item()
                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                max_seq_pages = (
                    metadata.max_seq_len_k + self.page_size - 1
                ) // self.page_size
                metadata.cache_seqlens_int32.copy_(
                    seq_lens + self.speculative_step_id + 1
                )
            else:
                # Normal Decode
                metadata = self.decode_cuda_graph_metadata[bs]
                max_len = seq_lens_cpu.max().item()
@@ -159,9 +345,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                metadata.max_seq_len_k = max_len
                metadata.cache_seqlens_int32.copy_(seq_lens)
                metadata.cu_seqlens_k[1:].copy_(
                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                )
            page_indices = self.req_to_token[
                req_pool_indices[:, None],
                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
            ]
            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
        elif forward_mode.is_target_verify():
            # Here we only support topk = 1 for now.
            metadata = self.target_verify_metadata[bs]
            metadata.cache_seqlens_int32.copy_(
                (seq_lens + self.speculative_num_draft_tokens)
            )
            metadata.max_seq_len_k = (
                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
            )
            max_len = seq_lens_cpu.max().item()
            metadata.cu_seqlens_k[1:].copy_(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )
            max_seq_pages = (
                metadata.max_seq_len_k + self.page_size - 1
            ) // self.page_size
            page_indices = self.req_to_token[
                req_pool_indices[:, None],
                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
            ]
            page_indices //= self.page_size
            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
            metadata.max_seq_len_k = seq_lens_cpu.max().item()
            max_len = seq_lens_cpu.max().item()
            metadata.cu_seqlens_k[1:].copy_(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )
            accept_length = spec_info.accept_length[:bs]
            if spec_info.accept_length_cpu:
                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
            else:
                metadata.max_seq_len_q = 1
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
            max_seq_pages = (
                metadata.max_seq_len_k + self.page_size - 1
            ) // self.page_size
            page_indices = self.req_to_token[
                req_pool_indices[:, None],
                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
            ]
            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
        self.forward_metadata = metadata
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        device = seqlens_in_batch.device

        if forward_batch.forward_mode.is_decode_or_idle():
            if forward_batch.spec_info is not None:
                # Draft Decode
                # Here we only support topk = 1 for now.
                metadata.cache_seqlens_int32 = (
                    seqlens_in_batch + (self.speculative_step_id + 1)
                ).to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
                    self.speculative_step_id + 1
                )
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
                    (1, 0),
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
            else:
                # Normal Decode
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                metadata.cu_seqlens_q = torch.arange(
                    0, batch_size + 1, dtype=torch.int32, device=device
                )
                metadata.cu_seqlens_k = torch.nn.functional.pad(
                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
                )
                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
                ]
        elif forward_batch.forward_mode.is_target_verify():
            # Only support topk = 1 for now.
            metadata.cache_seqlens_int32 = (
                forward_batch.seq_lens + self.speculative_num_draft_tokens
            ).to(torch.int32)
            metadata.max_seq_len_q = self.speculative_num_draft_tokens
            metadata.max_seq_len_k = (
                forward_batch.seq_lens_cpu.max().item()
                + self.speculative_num_draft_tokens
            )
            metadata.cu_seqlens_q = torch.arange(
                0,
                batch_size * self.speculative_num_draft_tokens + 1,
                self.speculative_num_draft_tokens,
                dtype=torch.int32,
                device=device,
            )
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
                (1, 0),
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        else:
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)


class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""

    def __init__(self, model_runner: ModelRunner, topk: int, speculative_num_steps: int):
        super().__init__(model_runner, topk, speculative_num_steps)
        for i in range(speculative_num_steps):
            self.attn_backends[i] = TRTLLMHAAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[i],
                kv_last_page_len_buf=self.kv_last_page_len,
                speculative_step_id=i,
            )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        forward_batch: ForwardBatch,
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                encoder_lens=forward_batch.encoder_lens,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )
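To make the index bookkeeping in this file concrete, here is a small standalone sketch. It is not part of the commit; the formulas mirror the hunks above (draft-decode cache_seqlens, target-verify cu_seqlens_q/cu_seqlens_k, and the page-count math), while the batch size, sequence lengths, and parameter values are invented for illustration.

import torch

# Illustrative values only.
bs = 3                               # batch size
seq_lens = torch.tensor([7, 12, 5])  # committed tokens per request
num_draft_tokens = 4                 # speculative_num_draft_tokens
step_id = 1                          # speculative_step_id of one draft-decode step
page_size = 64

# Draft decode: each request has step_id + 1 extra draft tokens in the KV cache.
cache_seqlens_draft = (seq_lens + (step_id + 1)).to(torch.int32)

# Target verify: every request contributes exactly num_draft_tokens query tokens,
# so cu_seqlens_q is a fixed-stride ramp.
cu_seqlens_q = torch.arange(
    0, bs * num_draft_tokens + 1, num_draft_tokens, dtype=torch.int32
)

# Keys include the draft tokens; cu_seqlens_k is the zero-padded cumulative sum.
cache_seqlens_verify = (seq_lens + num_draft_tokens).to(torch.int32)
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens_verify, dim=0, dtype=torch.int32), (1, 0)
)

# Page-table width: ceil(max_seq_len_k / page_size) pages per request.
max_seq_len_k = cache_seqlens_verify.max().item()
max_seq_pages = (max_seq_len_k + page_size - 1) // page_size

print(cache_seqlens_draft.tolist())  # [9, 14, 7]
print(cu_seqlens_q.tolist())         # [0, 4, 8, 12]
print(cu_seqlens_k.tolist())         # [0, 11, 27, 36]
print(max_seq_pages)                 # 1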
python/sglang/srt/server_args.py @ 9ec314c6
@@ -500,11 +500,6 @@ class ServerArgs:
            )
            self.page_size = 64

-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mha backend does not support speculative decoding yet."
-                )
-
        if self.attention_backend == "dual_chunk_flash_attn":
            logger.warning(
                "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
@@ -653,6 +648,16 @@ class ServerArgs:
                self.speculative_num_draft_tokens,
            ) = auto_choose_speculative_params(self)

+            if (
+                self.attention_backend == "trtllm_mha"
+                or self.decode_attention_backend == "trtllm_mha"
+                or self.prefill_attention_backend == "trtllm_mha"
+            ):
+                if self.speculative_eagle_topk > 1:
+                    raise ValueError(
+                        "trtllm_mha backend only supports topk = 1 for speculative decoding."
+                    )
+
            if (
                self.speculative_eagle_topk == 1
                and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
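A minimal sketch of the constraint this hunk enforces, using a hypothetical helper (not sglang's real API). The hunk is truncated before showing how a topk = 1 mismatch is resolved, so the reconciliation below is an assumption; the visible part only guarantees that topk > 1 is rejected for trtllm_mha.

def validate_trtllm_mha_spec_args(topk: int, num_steps: int, num_draft_tokens: int) -> int:
    """Hypothetical helper mirroring the server_args checks above."""
    if topk > 1:
        raise ValueError(
            "trtllm_mha backend only supports topk = 1 for speculative decoding."
        )
    if topk == 1 and num_draft_tokens != num_steps + 1:
        # Assumption: with a linear draft chain (topk = 1), num_steps drafted tokens
        # plus the root token are verified together, so num_steps + 1 is consistent.
        num_draft_tokens = num_steps + 1
    return num_draft_tokens

print(validate_trtllm_mha_spec_args(topk=1, num_steps=3, num_draft_tokens=4))  # 4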
python/sglang/srt/speculative/eagle_worker.py @ 9ec314c6
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
                self.topk,
                self.speculative_num_steps,
            )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "trtllm_mla":
            if not global_server_args_dict["use_mla_backend"]:
                raise ValueError(