Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
73d4a5f8
Unverified
Commit
73d4a5f8
authored
Oct 01, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 01, 2025
Browse files
Organize spec-related data structures (#10735)
parent
7fb551a7
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
866 additions
and
808 deletions
+866
-808
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+1
-1
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+10
-38
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...n/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+1
-1
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
...g/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+2
-1
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+183
-750
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+4
-2
python/sglang/srt/speculative/ngram_utils.py
python/sglang/srt/speculative/ngram_utils.py
+9
-4
python/sglang/srt/speculative/ngram_worker.py
python/sglang/srt/speculative/ngram_worker.py
+1
-5
python/sglang/srt/speculative/spec_info.py
python/sglang/srt/speculative/spec_info.py
+42
-0
python/sglang/srt/speculative/spec_utils.py
python/sglang/srt/speculative/spec_utils.py
+607
-0
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+5
-4
test/srt/test_forward_split_prefill.py
test/srt/test_forward_split_prefill.py
+1
-2
No files found.
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
73d4a5f8
...
@@ -821,7 +821,7 @@ class CudaGraphRunner:
...
@@ -821,7 +821,7 @@ class CudaGraphRunner:
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
or
self
.
model_runner
.
spec_algorithm
.
is_standalone
()
or
self
.
model_runner
.
spec_algorithm
.
is_standalone
()
):
):
from
sglang.srt.speculative.eagle_
utils
import
EagleVerifyInput
from
sglang.srt.speculative.eagle_
info
import
EagleVerifyInput
if
self
.
model_runner
.
is_draft_worker
:
if
self
.
model_runner
.
is_draft_worker
:
raise
RuntimeError
(
"This should not happen."
)
raise
RuntimeError
(
"This should not happen."
)
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
73d4a5f8
...
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
...
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size
,
get_attention_tp_size
,
set_dp_buffer_len
,
set_dp_buffer_len
,
)
)
from
sglang.srt.layers.rotary_embedding
import
MRotaryEmbedding
from
sglang.srt.utils
import
get_compiler_backend
,
is_npu
,
support_triton
from
sglang.srt.utils
import
(
flatten_nested_list
,
get_compiler_backend
,
is_npu
,
support_triton
,
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
...
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
...
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.spec_info
import
SpecInput
,
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
_is_npu
=
is_npu
()
_is_npu
=
is_npu
()
...
@@ -293,7 +286,7 @@ class ForwardBatch:
...
@@ -293,7 +286,7 @@ class ForwardBatch:
global_forward_mode
:
Optional
[
ForwardMode
]
=
None
global_forward_mode
:
Optional
[
ForwardMode
]
=
None
# Speculative decoding
# Speculative decoding
spec_info
:
Optional
[
Union
[
EagleVerifyInput
,
EagleDraft
Input
]
]
=
None
spec_info
:
Optional
[
Spec
Input
]
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
...
@@ -364,33 +357,14 @@ class ForwardBatch:
...
@@ -364,33 +357,14 @@ class ForwardBatch:
# For MLP sync
# For MLP sync
if
batch
.
global_num_tokens
is
not
None
:
if
batch
.
global_num_tokens
is
not
None
:
from
sglang.srt.speculative.eagle_utils
import
(
EagleDraftInput
,
EagleVerifyInput
,
)
assert
batch
.
global_num_tokens_for_logprob
is
not
None
assert
batch
.
global_num_tokens_for_logprob
is
not
None
# process global_num_tokens and global_num_tokens_for_logprob
# process global_num_tokens and global_num_tokens_for_logprob
if
batch
.
spec_info
is
not
None
:
if
batch
.
spec_info
is
not
None
:
if
isinstance
(
batch
.
spec_info
,
EagleDraftInput
):
spec_info
:
SpecInput
=
batch
.
spec_info
global_num_tokens
=
[
global_num_tokens
,
global_num_tokens_for_logprob
=
(
x
*
batch
.
spec_info
.
num_tokens_per_batch
spec_info
.
get_spec_adjusted_global_num_tokens
(
batch
)
for
x
in
batch
.
global_num_tokens
)
]
global_num_tokens_for_logprob
=
[
x
*
batch
.
spec_info
.
num_tokens_for_logprob_per_batch
for
x
in
batch
.
global_num_tokens_for_logprob
]
else
:
assert
isinstance
(
batch
.
spec_info
,
EagleVerifyInput
)
global_num_tokens
=
[
x
*
batch
.
spec_info
.
draft_token_num
for
x
in
batch
.
global_num_tokens
]
global_num_tokens_for_logprob
=
[
x
*
batch
.
spec_info
.
draft_token_num
for
x
in
batch
.
global_num_tokens_for_logprob
]
else
:
else
:
global_num_tokens
=
batch
.
global_num_tokens
global_num_tokens
=
batch
.
global_num_tokens
global_num_tokens_for_logprob
=
batch
.
global_num_tokens_for_logprob
global_num_tokens_for_logprob
=
batch
.
global_num_tokens_for_logprob
...
@@ -669,9 +643,6 @@ class ForwardBatch:
...
@@ -669,9 +643,6 @@ class ForwardBatch:
)
)
def
prepare_mlp_sync_batch
(
self
,
model_runner
:
ModelRunner
):
def
prepare_mlp_sync_batch
(
self
,
model_runner
:
ModelRunner
):
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
assert
self
.
global_num_tokens_cpu
is
not
None
assert
self
.
global_num_tokens_cpu
is
not
None
assert
self
.
global_num_tokens_for_logprob_cpu
is
not
None
assert
self
.
global_num_tokens_for_logprob_cpu
is
not
None
...
@@ -768,7 +739,8 @@ class ForwardBatch:
...
@@ -768,7 +739,8 @@ class ForwardBatch:
if
self
.
extend_seq_lens
is
not
None
:
if
self
.
extend_seq_lens
is
not
None
:
self
.
extend_seq_lens
=
self
.
_pad_tensor_to_size
(
self
.
extend_seq_lens
,
bs
)
self
.
extend_seq_lens
=
self
.
_pad_tensor_to_size
(
self
.
extend_seq_lens
,
bs
)
if
self
.
spec_info
is
not
None
and
isinstance
(
self
.
spec_info
,
EagleDraftInput
):
if
self
.
spec_info
is
not
None
and
self
.
spec_info
.
is_draft_input
():
# FIXME(lsyin): remove this isinstance logic
spec_info
=
self
.
spec_info
spec_info
=
self
.
spec_info
self
.
output_cache_loc_backup
=
self
.
out_cache_loc
self
.
output_cache_loc_backup
=
self
.
out_cache_loc
self
.
hidden_states_backup
=
spec_info
.
hidden_states
self
.
hidden_states_backup
=
spec_info
.
hidden_states
...
...
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
73d4a5f8
...
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.speculative.eagle_
utils
import
EagleDraftInput
from
sglang.srt.speculative.eagle_
info
import
EagleDraftInput
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
require_attn_tp_gather
,
require_attn_tp_gather
,
require_gathered_buffer
,
require_gathered_buffer
,
...
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
View file @
73d4a5f8
...
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
)
)
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
fast_topk
from
sglang.srt.speculative.eagle_info
import
EagleDraftInput
from
sglang.srt.speculative.spec_utils
import
fast_topk
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
require_attn_tp_gather
,
require_attn_tp_gather
,
require_gathered_buffer
,
require_gathered_buffer
,
...
...
python/sglang/srt/speculative/eagle_
utils
.py
→
python/sglang/srt/speculative/eagle_
info
.py
View file @
73d4a5f8
from
__future__
import
annotations
import
copy
import
logging
import
logging
import
os
from
copy
import
copy
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.environ
import
envs
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.schedule_batch
import
(
from
sglang.srt.managers.schedule_batch
import
(
Req
,
ScheduleBatch
,
ScheduleBatch
,
get_last_loc
,
get_last_loc
,
global_server_args_dict
,
global_server_args_dict
,
)
)
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
BaseTokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.speculative.spec_info
import
SpecInput
,
SpecInputType
from
sglang.srt.speculative.spec_utils
import
(
SIMULATE_ACC_LEN
,
TREE_SPEC_KERNEL_AVAILABLE
,
_generate_simulated_accept_index
,
align_evict_mask_to_page_size
,
assign_req_to_token_pool
,
create_accept_length_filter
,
create_extend_after_decode_spec_info
,
filter_finished_cache_loc_kernel
,
get_src_tgt_cache_loc
,
get_target_cache_loc
,
)
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
from
sglang.srt.utils
import
is_cuda
,
is_hip
,
next_power_of_2
if
is_cuda
():
if
is_cuda
():
from
sgl_kernel
import
(
from
sgl_kernel
import
(
fast_topk
,
top_k_renorm_prob
,
top_k_renorm_prob
,
top_p_renorm_prob
,
top_p_renorm_prob
,
tree_speculative_sampling_target_only
,
tree_speculative_sampling_target_only
,
verify_tree_greedy
,
verify_tree_greedy
,
)
)
elif
is_hip
():
elif
is_hip
():
from
sgl_kernel
import
fast_topk
,
verify_tree_greedy
from
sgl_kernel
import
verify_tree_greedy
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN
=
envs
.
SGLANG_SIMULATE_ACC_LEN
.
get
()
# turn off if < 0
SIMULATE_ACC_METHOD
=
envs
.
SGLANG_SIMULATE_ACC_METHOD
.
get
()
TREE_TRAVERSE_TIME_THRESHOLD
=
1
# TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE
=
"tree_speculative_sampling_target_only"
in
globals
()
@
dataclass
class
EagleDraftInput
:
# The inputs for decode
# shape: (b, topk)
topk_p
:
torch
.
Tensor
=
None
topk_index
:
torch
.
Tensor
=
None
# shape: (b, hidden_size)
hidden_states
:
torch
.
Tensor
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
CaptureHiddenMode
.
FULL
# Inputs for extend
# shape: (b,)
verified_id
:
torch
.
Tensor
=
None
accept_length
:
torch
.
Tensor
=
None
accept_length_cpu
:
List
[
int
]
=
None
# Inputs for the attention backends
# shape: (b + 1,)
kv_indptr
:
torch
.
Tensor
=
None
kv_indices
:
torch
.
Tensor
=
None
# Shape info for padding
num_tokens_per_batch
:
int
=
-
1
num_tokens_for_logprob_per_batch
:
int
=
-
1
# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend
:
torch
.
Tensor
=
None
req_pool_indices_for_draft_extend
:
torch
.
Tensor
=
None
def
prepare_for_extend
(
self
,
batch
:
ScheduleBatch
):
if
batch
.
forward_mode
.
is_idle
():
return
# Prefill only generate 1 token.
assert
len
(
self
.
verified_id
)
==
len
(
batch
.
seq_lens
)
pt
=
0
for
i
,
extend_len
in
enumerate
(
batch
.
extend_lens
):
input_ids
=
batch
.
input_ids
[
pt
:
pt
+
extend_len
]
batch
.
input_ids
[
pt
:
pt
+
extend_len
]
=
torch
.
cat
(
(
input_ids
[
1
:],
self
.
verified_id
[
i
].
reshape
(
1
))
)
pt
+=
extend_len
@
classmethod
def
create_idle_input
(
cls
,
device
:
torch
.
device
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
topk
:
int
,
capture_hidden_mode
:
CaptureHiddenMode
,
):
return
cls
(
verified_id
=
torch
.
empty
((
0
,),
device
=
device
,
dtype
=
torch
.
int32
),
hidden_states
=
torch
.
empty
((
0
,
hidden_size
),
device
=
device
,
dtype
=
dtype
),
topk_p
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
float32
),
topk_index
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
int64
),
capture_hidden_mode
=
capture_hidden_mode
,
accept_length
=
torch
.
empty
((
0
,),
device
=
device
,
dtype
=
torch
.
int32
),
accept_length_cpu
=
[],
)
def
prepare_extend_after_decode
(
self
,
batch
:
ScheduleBatch
,
speculative_num_steps
:
int
,
):
if
batch
.
forward_mode
.
is_idle
():
return
batch
.
input_ids
=
self
.
verified_id
batch
.
extend_lens
=
[
x
+
1
for
x
in
batch
.
spec_info
.
accept_length_cpu
]
batch
.
extend_num_tokens
=
sum
(
batch
.
extend_lens
)
batch
.
seq_lens
=
batch
.
spec_info
.
seq_lens_for_draft_extend
batch
.
req_pool_indices
=
batch
.
spec_info
.
req_pool_indices_for_draft_extend
batch
.
return_logprob
=
False
batch
.
return_hidden_states
=
False
self
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
self
.
accept_length
.
add_
(
1
)
self
.
positions
=
torch
.
empty_like
(
batch
.
input_ids
,
dtype
=
torch
.
long
)
self
.
verified_id
=
torch
.
empty_like
(
self
.
accept_length
,
dtype
=
torch
.
int32
)
create_extend_after_decode_spec_info
[(
len
(
batch
.
seq_lens
),)](
batch
.
input_ids
,
batch
.
seq_lens
,
self
.
accept_length
,
self
.
positions
,
self
.
verified_id
,
next_power_of_2
(
max
(
speculative_num_steps
+
1
,
len
(
batch
.
seq_lens
))),
)
def
generate_attn_arg_prefill
(
self
,
req_pool_indices
:
torch
.
Tensor
,
paged_kernel_lens
:
torch
.
Tensor
,
paged_kernel_lens_sum
:
int
,
req_to_token
:
torch
.
Tensor
,
):
bs
=
self
.
accept_length
.
numel
()
qo_indptr
=
torch
.
zeros
((
bs
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
qo_indptr
[
1
:]
=
torch
.
cumsum
(
self
.
accept_length
,
dim
=
0
)
cum_kv_seq_len
=
torch
.
zeros
((
bs
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
cum_kv_seq_len
[
1
:]
=
torch
.
cumsum
(
paged_kernel_lens
,
dim
=
0
)
if
paged_kernel_lens_sum
is
None
:
paged_kernel_lens_sum
=
cum_kv_seq_len
[
-
1
]
kv_indices
=
torch
.
empty
(
paged_kernel_lens_sum
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
create_flashinfer_kv_indices_triton
[(
bs
,)](
req_to_token
,
req_pool_indices
,
paged_kernel_lens
,
cum_kv_seq_len
,
None
,
kv_indices
,
req_to_token
.
size
(
1
),
)
return
kv_indices
,
cum_kv_seq_len
,
qo_indptr
,
None
def
filter_batch
(
self
,
new_indices
:
torch
.
Tensor
,
has_been_filtered
:
bool
=
True
):
if
has_been_filtered
:
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
if
len
(
new_indices
)
!=
len
(
self
.
topk_p
):
logger
.
warning
(
f
"length of new_indices:
{
len
(
new_indices
)
}
!= length of topk_p:
{
len
(
self
.
topk_p
)
}
, this should not happen"
)
self
.
topk_p
=
self
.
topk_p
[:
len
(
new_indices
)]
self
.
topk_index
=
self
.
topk_index
[:
len
(
new_indices
)]
self
.
hidden_states
=
self
.
hidden_states
[:
len
(
new_indices
)]
self
.
verified_id
=
self
.
verified_id
[:
len
(
new_indices
)]
else
:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self
.
topk_p
=
self
.
topk_p
[
new_indices
]
self
.
topk_index
=
self
.
topk_index
[
new_indices
]
self
.
hidden_states
=
self
.
hidden_states
[
new_indices
]
self
.
verified_id
=
self
.
verified_id
[
new_indices
]
def
merge_batch
(
self
,
spec_info
:
EagleDraftInput
):
if
self
.
hidden_states
is
None
:
self
.
hidden_states
=
spec_info
.
hidden_states
self
.
verified_id
=
spec_info
.
verified_id
self
.
topk_p
=
spec_info
.
topk_p
self
.
topk_index
=
spec_info
.
topk_index
return
if
spec_info
.
hidden_states
is
None
:
return
self
.
hidden_states
=
torch
.
cat
(
[
self
.
hidden_states
,
spec_info
.
hidden_states
],
axis
=
0
)
self
.
verified_id
=
torch
.
cat
([
self
.
verified_id
,
spec_info
.
verified_id
],
axis
=
0
)
self
.
topk_p
=
torch
.
cat
([
self
.
topk_p
,
spec_info
.
topk_p
])
self
.
topk_index
=
torch
.
cat
([
self
.
topk_index
,
spec_info
.
topk_index
])
@
dataclass
@
dataclass
class
EagleVerifyOutput
:
class
EagleVerifyInput
(
SpecInput
):
# Draft input batch
draft_input
:
EagleDraftInput
# Logit outputs from target worker
logits_output
:
LogitsProcessorOutput
# Accepted token ids including the bonus token
verified_id
:
torch
.
Tensor
# Accepted token length per sequence in a batch in CPU.
accept_length_per_req_cpu
:
List
[
int
]
# Accepted indices from logits_output.next_token_logits
accepted_indices
:
torch
.
Tensor
@
dataclass
class
EagleVerifyInput
:
draft_token
:
torch
.
Tensor
draft_token
:
torch
.
Tensor
custom_mask
:
torch
.
Tensor
custom_mask
:
torch
.
Tensor
positions
:
torch
.
Tensor
positions
:
torch
.
Tensor
...
@@ -245,6 +62,12 @@ class EagleVerifyInput:
...
@@ -245,6 +62,12 @@ class EagleVerifyInput:
seq_lens_cpu
:
torch
.
Tensor
seq_lens_cpu
:
torch
.
Tensor
grammar
:
BaseGrammarObject
=
None
grammar
:
BaseGrammarObject
=
None
def
__post_init__
(
self
):
super
().
__init__
(
SpecInputType
.
EAGLE_VERIFY
)
def
get_spec_adjust_token_coefficient
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
draft_token_num
,
self
.
draft_token_num
@
classmethod
@
classmethod
def
create_idle_input
(
cls
,
topk
:
int
,
spec_steps
:
int
,
num_verify_tokens
:
int
):
def
create_idle_input
(
cls
,
topk
:
int
,
spec_steps
:
int
,
num_verify_tokens
:
int
):
return
cls
(
return
cls
(
...
@@ -724,574 +547,184 @@ class EagleVerifyInput:
...
@@ -724,574 +547,184 @@ class EagleVerifyInput:
)
)
@
triton
.
jit
@
dataclass
def
create_extend_after_decode_spec_info
(
class
EagleDraftInput
(
SpecInput
):
verified_id
,
# The inputs for decode
seq_lens
,
# shape: (b, topk)
accept_lens
,
topk_p
:
torch
.
Tensor
=
None
positions
,
topk_index
:
torch
.
Tensor
=
None
new_verified_id
,
# shape: (b, hidden_size)
bs_upper
:
tl
.
constexpr
,
hidden_states
:
torch
.
Tensor
=
None
):
capture_hidden_mode
:
CaptureHiddenMode
=
CaptureHiddenMode
.
FULL
pid
=
tl
.
program_id
(
axis
=
0
)
offsets
=
tl
.
arange
(
0
,
bs_upper
)
seq_length
=
tl
.
load
(
seq_lens
+
pid
)
accept_length
=
tl
.
load
(
accept_lens
+
pid
)
accept_len_cumsum
=
tl
.
sum
(
tl
.
load
(
accept_lens
+
offsets
,
mask
=
offsets
<
pid
,
other
=
0
)
)
positions_ptr
=
positions
+
accept_len_cumsum
mask
=
offsets
<
accept_length
tl
.
store
(
positions_ptr
+
offsets
,
seq_length
-
accept_length
+
offsets
,
mask
)
accept_len_cumsum
+=
accept_length
-
1
verified_id_data
=
tl
.
load
(
verified_id
+
accept_len_cumsum
)
tl
.
store
(
new_verified_id
+
pid
,
verified_id_data
)
@
triton
.
jit
def
assign_req_to_token_pool
(
req_pool_indices
,
req_to_token
,
start_offset
,
end_offset
,
out_cache_loc
,
pool_len
:
tl
.
constexpr
,
bs_upper
:
tl
.
constexpr
,
):
BLOCK_SIZE
:
tl
.
constexpr
=
32
pid
=
tl
.
program_id
(
axis
=
0
)
kv_start
=
tl
.
load
(
start_offset
+
pid
)
kv_end
=
tl
.
load
(
end_offset
+
pid
)
token_pool
=
req_to_token
+
tl
.
load
(
req_pool_indices
+
pid
)
*
pool_len
length_offset
=
tl
.
arange
(
0
,
bs_upper
)
start
=
tl
.
load
(
start_offset
+
length_offset
,
mask
=
length_offset
<
pid
,
other
=
0
)
end
=
tl
.
load
(
end_offset
+
length_offset
,
mask
=
length_offset
<
pid
,
other
=
0
)
out_offset
=
tl
.
sum
(
end
-
start
,
axis
=
0
)
out_cache_ptr
=
out_cache_loc
+
out_offset
save_offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
kv_start
load_offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
num_loop
=
tl
.
cdiv
(
kv_end
-
kv_start
,
BLOCK_SIZE
)
for
_
in
range
(
num_loop
):
mask
=
save_offset
<
kv_end
data
=
tl
.
load
(
out_cache_ptr
+
load_offset
,
mask
=
mask
)
tl
.
store
(
token_pool
+
save_offset
,
data
,
mask
=
mask
)
save_offset
+=
BLOCK_SIZE
load_offset
+=
BLOCK_SIZE
@
triton
.
jit
def
assign_draft_cache_locs
(
req_pool_indices
,
req_to_token
,
seq_lens
,
extend_lens
,
num_new_pages_per_topk
,
out_cache_loc
,
pool_len
:
tl
.
constexpr
,
topk
:
tl
.
constexpr
,
speculative_num_steps
:
tl
.
constexpr
,
page_size
:
tl
.
constexpr
,
bs_upper
:
tl
.
constexpr
,
iter_upper
:
tl
.
constexpr
,
):
BLOCK_SIZE
:
tl
.
constexpr
=
128
pid
=
tl
.
program_id
(
axis
=
0
)
if
page_size
==
1
or
topk
==
1
:
copy_len
=
topk
*
speculative_num_steps
out_cache_ptr
=
out_cache_loc
+
pid
*
topk
*
speculative_num_steps
else
:
bs_offset
=
tl
.
arange
(
0
,
bs_upper
)
copy_len
=
tl
.
load
(
extend_lens
+
pid
)
cum_copy_len
=
tl
.
sum
(
tl
.
load
(
extend_lens
+
bs_offset
,
mask
=
bs_offset
<
pid
))
out_cache_ptr
=
out_cache_loc
+
cum_copy_len
# Part 1: Copy from out_cache_loc to req_to_token
kv_start
=
tl
.
load
(
seq_lens
+
pid
)
token_pool
=
req_to_token
+
tl
.
load
(
req_pool_indices
+
pid
)
*
pool_len
num_loop
=
tl
.
cdiv
(
copy_len
,
BLOCK_SIZE
)
for
i
in
range
(
num_loop
):
copy_offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
i
*
BLOCK_SIZE
mask
=
copy_offset
<
copy_len
data
=
tl
.
load
(
out_cache_ptr
+
copy_offset
,
mask
=
mask
)
tl
.
store
(
token_pool
+
kv_start
+
copy_offset
,
data
,
mask
=
mask
)
if
page_size
==
1
or
topk
==
1
:
return
# Part 2: Copy the indices for the last partial page
prefix_len
=
tl
.
load
(
seq_lens
+
pid
)
last_page_len
=
prefix_len
%
page_size
offsets
=
tl
.
arange
(
0
,
page_size
)
mask
=
offsets
<
last_page_len
num_new_pages_per_topk_
=
tl
.
load
(
num_new_pages_per_topk
+
pid
)
prefix_base
=
token_pool
+
prefix_len
-
last_page_len
for
topk_id
in
range
(
topk
):
value
=
tl
.
load
(
prefix_base
+
offsets
,
mask
=
mask
)
tl
.
store
(
prefix_base
+
topk_id
*
num_new_pages_per_topk_
*
page_size
+
offsets
,
value
,
mask
=
mask
,
)
# Part 3: Remove the padding in out_cache_loc
iter_offest
=
tl
.
arange
(
0
,
iter_upper
)
for
topk_id
in
range
(
topk
):
indices
=
tl
.
load
(
prefix_base
+
topk_id
*
num_new_pages_per_topk_
*
page_size
+
last_page_len
+
iter_offest
,
mask
=
iter_offest
<
speculative_num_steps
,
)
tl
.
store
(
out_cache_loc
+
pid
*
topk
*
speculative_num_steps
+
topk_id
*
speculative_num_steps
+
iter_offest
,
indices
,
mask
=
iter_offest
<
speculative_num_steps
,
)
# Inputs for extend
# shape: (b,)
verified_id
:
torch
.
Tensor
=
None
accept_length
:
torch
.
Tensor
=
None
accept_length_cpu
:
List
[
int
]
=
None
@
triton
.
jit
# Inputs for the attention backends
def
generate_draft_decode_kv_indices
(
# shape: (b + 1,)
req_pool_indices
,
kv_indptr
:
torch
.
Tensor
=
None
req_to_token
,
kv_indices
:
torch
.
Tensor
=
None
paged_kernel_lens
,
kv_indices
,
kv_indptr
,
positions
,
pool_len
:
tl
.
constexpr
,
kv_indices_stride
:
tl
.
constexpr
,
kv_indptr_stride
:
tl
.
constexpr
,
bs_upper
:
tl
.
constexpr
,
iter_upper
:
tl
.
constexpr
,
num_tokens_upper
:
tl
.
constexpr
,
page_size
:
tl
.
constexpr
,
):
BLOCK_SIZE
:
tl
.
constexpr
=
128
iters
=
tl
.
program_id
(
axis
=
0
)
bid
=
tl
.
program_id
(
axis
=
1
)
topk_id
=
tl
.
program_id
(
axis
=
2
)
num_steps
=
tl
.
num_programs
(
axis
=
0
)
num_seqs
=
tl
.
num_programs
(
axis
=
1
)
topk
=
tl
.
num_programs
(
axis
=
2
)
kv_indices
+=
kv_indices_stride
*
iters
kv_indptr
+=
kv_indptr_stride
*
iters
iters
+=
1
load_offset
=
tl
.
arange
(
0
,
bs_upper
)
seq_lens
=
tl
.
load
(
paged_kernel_lens
+
load_offset
,
mask
=
load_offset
<
bid
,
other
=
0
)
seq_len
=
tl
.
load
(
paged_kernel_lens
+
bid
)
cum_seq_len
=
tl
.
sum
(
seq_lens
)
# Update kv_indices
kv_offset
=
cum_seq_len
*
topk
+
bid
*
iters
*
topk
+
topk_id
*
(
seq_len
+
iters
)
kv_ptr
=
kv_indices
+
kv_offset
token_pool_ptr
=
req_to_token
+
tl
.
load
(
req_pool_indices
+
bid
)
*
pool_len
kv_offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
num_loop
=
tl
.
cdiv
(
seq_len
,
BLOCK_SIZE
)
for
_
in
range
(
num_loop
):
mask
=
kv_offset
<
seq_len
data
=
tl
.
load
(
token_pool_ptr
+
kv_offset
,
mask
=
mask
)
tl
.
store
(
kv_ptr
+
kv_offset
,
data
,
mask
=
mask
)
kv_offset
+=
BLOCK_SIZE
extend_offset
=
tl
.
arange
(
0
,
iter_upper
)
if
page_size
==
1
or
topk
==
1
:
extend_data
=
tl
.
load
(
token_pool_ptr
+
seq_len
+
topk_id
*
num_steps
+
tl
.
arange
(
0
,
iter_upper
),
mask
=
extend_offset
<
iters
,
)
else
:
prefix_len
=
seq_len
last_page_len
=
prefix_len
%
page_size
num_new_pages_per_topk
=
(
last_page_len
+
num_steps
+
page_size
-
1
)
//
page_size
prefix_base
=
seq_len
//
page_size
*
page_size
start
=
(
prefix_base
+
topk_id
*
num_new_pages_per_topk
*
page_size
+
last_page_len
)
extend_data
=
tl
.
load
(
token_pool_ptr
+
start
+
extend_offset
,
mask
=
extend_offset
<
iters
,
)
tl
.
store
(
kv_ptr
+
seq_len
+
extend_offset
,
extend_data
,
mask
=
extend_offset
<
iters
)
# Shape info for padding
num_tokens_per_batch
:
int
=
-
1
# Update kv_indptr
num_tokens_for_logprob_per_batch
:
int
=
-
1
bs_offset
=
tl
.
arange
(
0
,
num_tokens_upper
)
zid
=
bid
*
topk
+
topk_id
if
zid
==
0
:
zid
=
num_seqs
*
topk
positions
=
tl
.
load
(
positions
+
bs_offset
,
mask
=
bs_offset
<
zid
,
other
=
0
)
base
=
tl
.
sum
(
positions
)
tl
.
store
(
kv_indptr
+
zid
,
base
+
zid
*
iters
)
@
triton
.
jit
def
align_evict_mask_to_page_size
(
seq_lens
,
evict_mask
,
page_size
:
tl
.
constexpr
,
num_draft_tokens
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
t_range
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
bid
=
tl
.
program_id
(
axis
=
0
)
seq_len
=
tl
.
load
(
seq_lens
+
bid
)
io_mask
=
t_range
<
num_draft_tokens
mask_row
=
tl
.
load
(
evict_mask
+
bid
*
num_draft_tokens
+
t_range
,
mask
=
io_mask
,
other
=
0
)
num_trues
=
tl
.
sum
(
mask_row
)
# Inputs for draft extend
num_false
=
num_draft_tokens
-
num_trues
# shape: (b,)
seq_lens_for_draft_extend
:
torch
.
Tensor
=
None
start
=
(
seq_len
+
num_false
-
1
)
//
page_size
*
page_size
-
seq_len
req_pool_indices_for_draft_extend
:
torch
.
Tensor
=
None
for
i
in
range
(
max
(
start
,
0
),
min
(
start
+
page_size
,
num_draft_tokens
)):
tl
.
store
(
evict_mask
+
bid
*
num_draft_tokens
+
i
,
False
)
@
triton
.
jit
def
get_target_cache_loc
(
tgt_cache_loc
,
to_free_slots
,
accept_length
,
to_free_num_slots
,
out_cache_loc
,
num_verify_tokens
:
tl
.
constexpr
,
num_verify_tokens_upper
:
tl
.
constexpr
,
bs_upper
:
tl
.
constexpr
,
):
bid
=
tl
.
program_id
(
axis
=
0
)
offset
=
tl
.
arange
(
0
,
num_verify_tokens_upper
)
bs_offset
=
tl
.
arange
(
0
,
bs_upper
)
# write the first part to tgt_cache_loc
accept_len_all
=
tl
.
load
(
accept_length
+
bs_offset
,
mask
=
bs_offset
<
bid
)
tgt_cache_loc_start
=
tl
.
sum
(
accept_len_all
)
+
bid
copy_len
=
tl
.
load
(
accept_length
+
bid
)
+
1
out_cache_loc_row
=
tl
.
load
(
out_cache_loc
+
bid
*
num_verify_tokens
+
offset
,
mask
=
offset
<
copy_len
)
tl
.
store
(
tgt_cache_loc
+
tgt_cache_loc_start
+
offset
,
out_cache_loc_row
,
mask
=
offset
<
copy_len
,
)
# write the second part to to_free_num_pages
def
__post_init__
(
self
):
to_free_num_slots_all
=
tl
.
load
(
to_free_num_slots
+
bs_offset
,
mask
=
bs_offset
<
bid
)
super
().
__init__
(
SpecInputType
.
EAGLE_DRAFT
)
to_free_num_slots_cur
=
tl
.
load
(
to_free_num_slots
+
bid
)
out_cache_loc_start
=
num_verify_tokens
-
to_free_num_slots_cur
to_free_slots_start
=
tl
.
sum
(
to_free_num_slots_all
)
copy_len
=
to_free_num_slots_cur
def
get_spec_adjust_token_coefficient
(
self
)
->
Tuple
[
int
,
int
]:
out_cache_loc_row
=
tl
.
load
(
return
self
.
num_tokens_per_batch
,
self
.
num_tokens_for_logprob_per_batch
out_cache_loc
+
bid
*
num_verify_tokens
+
out_cache_loc_start
+
offset
,
mask
=
offset
<
copy_len
,
)
tl
.
store
(
to_free_slots
+
to_free_slots_start
+
offset
,
out_cache_loc_row
,
mask
=
offset
<
copy_len
,
)
def
prepare_for_extend
(
self
,
batch
:
ScheduleBatch
):
@
torch
.
compile
(
dynamic
=
True
)
if
batch
.
forward_mode
.
is_idle
():
def
get_src_tgt_cache_loc
(
return
seq_lens
:
torch
.
Tensor
,
out_cache_loc
:
torch
.
Tensor
,
accept_index
:
torch
.
Tensor
,
accept_length
:
torch
.
Tensor
,
draft_token_num
:
int
,
page_size
:
int
,
):
src_cache_loc
=
out_cache_loc
[
accept_index
]
tgt_cache_loc
=
torch
.
empty_like
(
src_cache_loc
)
extended_len
=
seq_lens
+
draft_token_num
keep_len
=
torch
.
minimum
(
(
seq_lens
+
accept_length
+
1
+
page_size
-
1
)
//
page_size
*
page_size
,
extended_len
,
)
to_free_num_slots
=
extended_len
-
keep_len
return
src_cache_loc
,
tgt_cache_loc
,
to_free_num_slots
@
triton
.
jit
def
filter_finished_cache_loc_kernel
(
out_cache_loc
,
tgt_cache_loc
,
accept_length
,
accept_length_filter
,
bs_upper
:
tl
.
constexpr
,
num_verify_tokens_upper
:
tl
.
constexpr
,
):
bid
=
tl
.
program_id
(
0
)
bs_offset
=
tl
.
arange
(
0
,
bs_upper
)
accept_length_all
=
tl
.
load
(
accept_length
+
bs_offset
,
mask
=
bs_offset
<
bid
)
old_start
=
tl
.
sum
(
accept_length_all
)
+
bid
accept_length_filter_all
=
tl
.
load
(
accept_length_filter
+
bs_offset
,
mask
=
bs_offset
<
bid
)
new_start
=
tl
.
sum
(
accept_length_filter_all
)
copy_len
=
tl
.
load
(
accept_length_filter
+
bid
)
# Prefill only generate 1 token.
copy_offset
=
tl
.
arange
(
0
,
num_verify_tokens_upper
)
assert
len
(
self
.
verified_id
)
==
len
(
batch
.
seq_lens
)
value
=
tl
.
load
(
tgt_cache_loc
+
old_start
+
copy_offset
,
mask
=
copy_offset
<
copy_len
)
tl
.
store
(
out_cache_loc
+
new_start
+
copy_offset
,
value
,
mask
=
copy_offset
<
copy_len
)
pt
=
0
for
i
,
extend_len
in
enumerate
(
batch
.
extend_lens
):
input_ids
=
batch
.
input_ids
[
pt
:
pt
+
extend_len
]
batch
.
input_ids
[
pt
:
pt
+
extend_len
]
=
torch
.
cat
(
(
input_ids
[
1
:],
self
.
verified_id
[
i
].
reshape
(
1
))
)
pt
+=
extend_len
@
torch
.
compile
(
dynamic
=
True
)
@
classmethod
def
create_accept_length_filter
(
def
create_idle_input
(
accept_length
:
torch
.
Tensor
,
cls
,
unfinished_index_device
:
torch
.
Tensor
,
device
:
torch
.
device
,
seq_lens
:
torch
.
Tensor
,
hidden_size
:
int
,
):
dtype
:
torch
.
dtype
,
accept_length_filter
=
torch
.
zeros_like
(
accept_length
)
topk
:
int
,
accept_length_filter
[
unfinished_index_device
]
=
(
capture_hidden_mode
:
CaptureHiddenMode
,
accept_length
[
unfinished_index_device
]
+
1
):
)
return
cls
(
seq_lens
.
add_
(
accept_length
+
1
)
verified_id
=
torch
.
empty
((
0
,),
device
=
device
,
dtype
=
torch
.
int32
),
return
accept_length_filter
hidden_states
=
torch
.
empty
((
0
,
hidden_size
),
device
=
device
,
dtype
=
dtype
),
topk_p
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
float32
),
topk_index
=
torch
.
empty
((
0
,
topk
),
device
=
device
,
dtype
=
torch
.
int64
),
@
torch
.
compile
(
dynamic
=
True
)
capture_hidden_mode
=
capture_hidden_mode
,
def
select_top_k_tokens
(
accept_length
=
torch
.
empty
((
0
,),
device
=
device
,
dtype
=
torch
.
int32
),
i
:
int
,
accept_length_cpu
=
[],
topk_p
:
torch
.
Tensor
,
topk_index
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
scores
:
torch
.
Tensor
,
topk
:
int
,
):
if
i
==
0
:
# The first step after extend
input_ids
=
topk_index
.
flatten
()
hidden_states
=
hidden_states
.
repeat_interleave
(
topk
,
dim
=
0
)
scores
=
topk_p
# shape: (b, topk)
tree_info
=
(
topk_p
.
unsqueeze
(
1
),
# shape: (b, 1, topk)
topk_index
,
# shape: (b, topk)
torch
.
arange
(
-
1
,
topk
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
.
unsqueeze
(
0
)
.
repeat
(
topk_p
.
shape
[
0
],
1
),
# shape: (b, topk + 1)
)
else
:
# The later decode steps
expand_scores
=
torch
.
mul
(
scores
.
unsqueeze
(
2
),
topk_p
.
reshape
(
-
1
,
topk
,
topk
)
)
# (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs_p
,
topk_cs_index
=
fast_topk
(
expand_scores
.
flatten
(
start_dim
=
1
),
topk
,
dim
=-
1
)
# (b, topk)
scores
=
topk_cs_p
# shape: (b, topk)
topk_index
=
topk_index
.
reshape
(
-
1
,
topk
**
2
)
input_ids
=
torch
.
gather
(
topk_index
,
index
=
topk_cs_index
,
dim
=
1
).
flatten
()
if
hidden_states
.
shape
[
0
]
>
0
:
selected_input_index
=
topk_cs_index
.
flatten
()
//
topk
+
torch
.
arange
(
0
,
hidden_states
.
shape
[
0
],
step
=
topk
,
device
=
"cuda"
).
repeat_interleave
(
topk
)
hidden_states
=
hidden_states
[
selected_input_index
,
:]
tree_info
=
(
expand_scores
,
# shape: (b, topk, topk)
topk_index
,
# shape: (b, topk * topk)
topk_cs_index
+
(
topk
**
2
*
(
i
-
1
)
+
topk
),
# shape: (b, topk)
)
)
return
input_ids
,
hidden_states
,
scores
,
tree_info
def
prepare_extend_after_decode
(
self
,
batch
:
ScheduleBatch
,
def
_generate_simulated_accept_index
(
speculative_num_steps
:
int
,
accept_index
,
):
predict
,
accept_length
,
if
batch
.
forward_mode
.
is_idle
():
bs
,
return
spec_steps
,
simulate_acc_len
:
float
=
SIMULATE_ACC_LEN
,
batch
.
input_ids
=
self
.
verified_id
simulate_acc_method
:
str
=
SIMULATE_ACC_METHOD
,
batch
.
extend_lens
=
[
x
+
1
for
x
in
batch
.
spec_info
.
accept_length_cpu
]
):
batch
.
extend_num_tokens
=
sum
(
batch
.
extend_lens
)
assert
simulate_acc_len
>
0.0
batch
.
seq_lens
=
batch
.
spec_info
.
seq_lens_for_draft_extend
batch
.
req_pool_indices
=
batch
.
spec_info
.
req_pool_indices_for_draft_extend
if
simulate_acc_method
==
"multinomial"
:
batch
.
return_logprob
=
False
simulated_values
=
torch
.
normal
(
batch
.
return_hidden_states
=
False
mean
=
simulate_acc_len
,
std
=
1.0
,
size
=
(
1
,),
device
=
"cpu"
,
)
# clamp simulated values to be between 1 and self.spec_steps
simulated_values
=
torch
.
clamp
(
simulated_values
,
min
=
1.0
,
max
=
spec_steps
+
1
)
simulate_acc_len
=
int
(
simulated_values
.
round
().
item
())
elif
simulate_acc_method
==
"match-expected"
:
# multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to
# match the expected length, One caveat is that this will only sample
# either round down or round up of the expected length
simulate_acc_len
=
max
(
1.0
,
min
(
spec_steps
+
1
,
simulate_acc_len
))
lower
=
int
(
simulate_acc_len
//
1
)
upper
=
lower
+
1
if
lower
<
spec_steps
+
1
else
lower
if
lower
==
upper
:
simulate_acc_len
=
lower
else
:
weight_upper
=
simulate_acc_len
-
lower
weight_lower
=
1.0
-
weight_upper
probs
=
torch
.
tensor
([
weight_lower
,
weight_upper
],
device
=
"cpu"
)
sampled_index
=
torch
.
multinomial
(
probs
,
num_samples
=
1
)
simulate_acc_len
=
lower
if
sampled_index
==
0
else
upper
else
:
raise
ValueError
(
f
"Invalid simulate_acc_method:
{
SIMULATE_ACC_METHOD
}
"
)
accept_indx_first_col
=
accept_index
[:,
0
].
view
(
-
1
,
1
)
sim_accept_index
=
torch
.
full
(
(
bs
,
spec_steps
+
1
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sim_accept_index
[:,
:
simulate_acc_len
]
=
accept_indx_first_col
+
torch
.
arange
(
simulate_acc_len
,
device
=
accept_index
.
device
)
accept_length
.
fill_
(
simulate_acc_len
-
1
)
predict
.
fill_
(
100
)
# some legit token id
return
sim_accept_index
def
traverse_tree
(
retrieve_next_token
:
torch
.
Tensor
,
retrieve_next_sibling
:
torch
.
Tensor
,
draft_tokens
:
torch
.
Tensor
,
grammar
:
BaseGrammarObject
,
allocate_token_bitmask
:
torch
.
Tensor
,
):
"""
Traverse the tree constructed by the draft model to generate the logits mask.
"""
assert
(
retrieve_next_token
.
shape
==
retrieve_next_sibling
.
shape
==
draft_tokens
.
shape
)
allocate_token_bitmask
.
fill_
(
0
)
self
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
self
.
accept_length
.
add_
(
1
)
self
.
positions
=
torch
.
empty_like
(
batch
.
input_ids
,
dtype
=
torch
.
long
)
self
.
verified_id
=
torch
.
empty_like
(
self
.
accept_length
,
dtype
=
torch
.
int32
)
def
dfs
(
create_extend_after_decode_spec_info
[(
len
(
batch
.
seq_lens
),)](
curr
:
int
,
batch
.
input_ids
,
retrieve_next_token
:
torch
.
Tensor
,
batch
.
seq_lens
,
retrieve_next_sibling
:
torch
.
Tensor
,
self
.
accept_length
,
parent_pos
:
int
,
self
.
positions
,
self
.
verified_id
,
next_power_of_2
(
max
(
speculative_num_steps
+
1
,
len
(
batch
.
seq_lens
))),
)
def
generate_attn_arg_prefill
(
self
,
req_pool_indices
:
torch
.
Tensor
,
paged_kernel_lens
:
torch
.
Tensor
,
paged_kernel_lens_sum
:
int
,
req_to_token
:
torch
.
Tensor
,
):
):
if
curr
==
0
:
bs
=
self
.
accept_length
.
numel
()
# the first token generated by the target model, and thus it is always
qo_indptr
=
torch
.
zeros
((
bs
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# accepted from the previous iteration
qo_indptr
[
1
:]
=
torch
.
cumsum
(
self
.
accept_length
,
dim
=
0
)
accepted
=
True
cum_kv_seq_len
=
torch
.
zeros
((
bs
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
else
:
cum_kv_seq_len
[
1
:]
=
torch
.
cumsum
(
paged_kernel_lens
,
dim
=
0
)
parent_bitmask
=
allocate_token_bitmask
[
parent_pos
]
curr_token_id
=
draft_tokens
[
curr
]
# 32 boolean bitmask values are packed into 32-bit integers
accepted
=
(
parent_bitmask
[
curr_token_id
//
32
]
&
(
1
<<
(
curr_token_id
%
32
))
)
!=
0
if
accepted
:
if
curr
!=
0
:
# Accept the current token
grammar
.
accept_token
(
draft_tokens
[
curr
])
if
not
grammar
.
is_terminated
():
# Generate the bitmask for the current token
grammar
.
fill_vocab_mask
(
allocate_token_bitmask
,
curr
)
if
retrieve_next_token
[
curr
]
!=
-
1
:
# Visit the child node
dfs
(
retrieve_next_token
[
curr
],
retrieve_next_token
,
retrieve_next_sibling
,
curr
,
)
if
curr
!=
0
:
if
paged_kernel_lens_sum
is
None
:
# Rollback the current token
paged_kernel_lens_sum
=
cum_kv_seq_len
[
-
1
]
grammar
.
rollback
(
1
)
if
retrieve_next_sibling
[
curr
]
!=
-
1
:
# Visit the sibling node
dfs
(
retrieve_next_sibling
[
curr
],
retrieve_next_token
,
retrieve_next_sibling
,
parent_pos
,
)
dfs
(
0
,
retrieve_next_token
,
retrieve_next_sibling
,
-
1
)
kv_indices
=
torch
.
empty
(
paged_kernel_lens_sum
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
def
generate_token_bitmask
(
reqs
:
List
[
Req
],
create_flashinfer_kv_indices_triton
[(
bs
,)](
verify_input
:
EagleVerifyInput
,
req_to_token
,
retrieve_next_token_cpu
:
torch
.
Tensor
,
req_pool_indices
,
retrieve_next_sibling_cpu
:
torch
.
Tensor
,
paged_kernel_lens
,
draft_tokens_cpu
:
torch
.
Tensor
,
cum_kv_seq_len
,
vocab_size
:
int
,
None
,
):
kv_indices
,
"""
req_to_token
.
size
(
1
),
Generate the logit mask for structured output.
)
Draft model's token can be either valid or invalid with respect to the grammar.
return
kv_indices
,
cum_kv_seq_len
,
qo_indptr
,
None
We need to perform DFS to
1. figure out which tokens are accepted by the grammar.
def
filter_batch
(
self
,
new_indices
:
torch
.
Tensor
,
has_been_filtered
:
bool
=
True
):
2. if so, what is the corresponding logit mask.
if
has_been_filtered
:
"""
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
num_draft_tokens
=
draft_tokens_cpu
.
shape
[
-
1
]
if
len
(
new_indices
)
!=
len
(
self
.
topk_p
):
allocate_token_bitmask
=
None
assert
len
(
reqs
)
==
retrieve_next_token_cpu
.
shape
[
0
]
grammar
=
None
for
i
,
req
in
enumerate
(
reqs
):
if
req
.
grammar
is
not
None
:
if
allocate_token_bitmask
is
None
:
allocate_token_bitmask
=
req
.
grammar
.
allocate_vocab_mask
(
vocab_size
=
vocab_size
,
batch_size
=
draft_tokens_cpu
.
numel
(),
device
=
"cpu"
,
)
grammar
=
req
.
grammar
s
=
time
.
perf_counter
()
traverse_tree
(
retrieve_next_token_cpu
[
i
],
retrieve_next_sibling_cpu
[
i
],
draft_tokens_cpu
[
i
],
req
.
grammar
,
allocate_token_bitmask
[
i
*
num_draft_tokens
:
(
i
+
1
)
*
num_draft_tokens
],
)
tree_traverse_time
=
time
.
perf_counter
()
-
s
if
tree_traverse_time
>
TREE_TRAVERSE_TIME_THRESHOLD
:
logger
.
warning
(
logger
.
warning
(
f
"Bit mask generation took
{
tree_traverse_time
}
seconds with "
f
"length of new_indices:
{
len
(
new_indices
)
}
!= length of topk_p:
{
len
(
self
.
topk_p
)
}
, this should not happen"
f
"grammar:
{
req
.
grammar
}
"
)
)
self
.
topk_p
=
self
.
topk_p
[:
len
(
new_indices
)]
self
.
topk_index
=
self
.
topk_index
[:
len
(
new_indices
)]
self
.
hidden_states
=
self
.
hidden_states
[:
len
(
new_indices
)]
self
.
verified_id
=
self
.
verified_id
[:
len
(
new_indices
)]
else
:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self
.
topk_p
=
self
.
topk_p
[
new_indices
]
self
.
topk_index
=
self
.
topk_index
[
new_indices
]
self
.
hidden_states
=
self
.
hidden_states
[
new_indices
]
self
.
verified_id
=
self
.
verified_id
[
new_indices
]
def merge_batch(self, spec_info: "EagleDraftInput"):
    """Concatenate another EagleDraftInput's batch tensors onto this one.

    If this side is empty (``hidden_states is None``) it adopts the other
    input's tensors wholesale; if the other side is empty, nothing changes.
    Otherwise ``hidden_states``, ``verified_id``, ``topk_p`` and
    ``topk_index`` are concatenated along the batch dimension (dim 0).
    """
    if self.hidden_states is None:
        # This side is empty: take over the other batch's tensors directly.
        self.hidden_states = spec_info.hidden_states
        self.verified_id = spec_info.verified_id
        self.topk_p = spec_info.topk_p
        self.topk_index = spec_info.topk_index
        return
    if spec_info.hidden_states is None:
        # The other side is empty: nothing to merge.
        return
    for field in ("hidden_states", "verified_id", "topk_p", "topk_index"):
        merged = torch.cat((getattr(self, field), getattr(spec_info, field)), dim=0)
        setattr(self, field, merged)
verify_input
.
grammar
=
grammar
@
dataclass
return
allocate_token_bitmask
class
EagleVerifyOutput
:
# Draft input batch
draft_input
:
EagleDraftInput
# Logit outputs from target worker
logits_output
:
LogitsProcessorOutput
# Accepted token ids including the bonus token
verified_id
:
torch
.
Tensor
# Accepted token length per sequence in a batch in CPU.
accept_length_per_req_cpu
:
List
[
int
]
# Accepted indices from logits_output.next_token_logits
accepted_indices
:
torch
.
Tensor
python/sglang/srt/speculative/eagle_worker.py
View file @
73d4a5f8
...
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
...
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from
sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner
import
(
from
sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner
import
(
EAGLEDraftExtendCudaGraphRunner
,
EAGLEDraftExtendCudaGraphRunner
,
)
)
from
sglang.srt.speculative.eagle_
utils
import
(
from
sglang.srt.speculative.eagle_
info
import
(
EagleDraftInput
,
EagleDraftInput
,
EagleVerifyInput
,
EagleVerifyInput
,
EagleVerifyOutput
,
EagleVerifyOutput
,
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_utils
import
(
assign_draft_cache_locs
,
assign_draft_cache_locs
,
fast_topk
,
fast_topk
,
generate_token_bitmask
,
generate_token_bitmask
,
select_top_k_tokens
,
select_top_k_tokens
,
)
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
empty_context
,
empty_context
,
get_available_gpu_memory
,
get_available_gpu_memory
,
...
...
python/sglang/srt/speculative/ngram_utils.py
View file @
73d4a5f8
...
@@ -2,7 +2,7 @@ from __future__ import annotations
...
@@ -2,7 +2,7 @@ from __future__ import annotations
import
copy
import
copy
import
logging
import
logging
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
import
torch
import
triton
import
triton
...
@@ -13,6 +13,7 @@ from dataclasses import dataclass
...
@@ -13,6 +13,7 @@ from dataclasses import dataclass
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.schedule_batch
import
(
from
sglang.srt.managers.schedule_batch
import
(
...
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict
,
global_server_args_dict
,
)
)
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.speculative.eagle_utils
import
(
from
sglang.srt.speculative.spec_info
import
SpecInput
,
SpecInputType
from
sglang.srt.speculative.spec_utils
import
(
TREE_SPEC_KERNEL_AVAILABLE
,
TREE_SPEC_KERNEL_AVAILABLE
,
assign_req_to_token_pool
,
assign_req_to_token_pool
,
create_flashinfer_kv_indices_triton
,
get_src_tgt_cache_loc
,
get_src_tgt_cache_loc
,
get_target_cache_loc
,
get_target_cache_loc
,
)
)
...
@@ -42,7 +43,7 @@ elif is_hip():
...
@@ -42,7 +43,7 @@ elif is_hip():
@
dataclass
@
dataclass
class
NgramVerifyInput
:
class
NgramVerifyInput
(
SpecInput
)
:
def
__init__
(
def
__init__
(
self
,
self
,
draft_token
:
torch
.
Tensor
,
draft_token
:
torch
.
Tensor
,
...
@@ -53,6 +54,7 @@ class NgramVerifyInput:
...
@@ -53,6 +54,7 @@ class NgramVerifyInput:
retrive_next_sibling
:
torch
.
Tensor
,
retrive_next_sibling
:
torch
.
Tensor
,
draft_token_num
:
int
,
draft_token_num
:
int
,
):
):
super
().
__init__
(
SpecInputType
.
NGRAM_VERIFY
)
self
.
draft_token
=
draft_token
self
.
draft_token
=
draft_token
self
.
custom_mask
=
tree_mask
self
.
custom_mask
=
tree_mask
self
.
positions
=
positions
self
.
positions
=
positions
...
@@ -62,6 +64,9 @@ class NgramVerifyInput:
...
@@ -62,6 +64,9 @@ class NgramVerifyInput:
self
.
draft_token_num
=
draft_token_num
self
.
draft_token_num
=
draft_token_num
self
.
device
=
self
.
custom_mask
.
device
self
.
device
=
self
.
custom_mask
.
device
def
get_spec_adjust_token_coefficient
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
draft_token_num
,
self
.
draft_token_num
def
prepare_for_verify
(
self
,
batch
:
ScheduleBatch
,
page_size
:
int
):
def
prepare_for_verify
(
self
,
batch
:
ScheduleBatch
,
page_size
:
int
):
if
batch
.
forward_mode
.
is_idle
():
if
batch
.
forward_mode
.
is_idle
():
return
return
...
...
python/sglang/srt/speculative/ngram_worker.py
View file @
73d4a5f8
import
logging
import
logging
import
os
from
typing
import
List
,
Optional
import
threading
import
time
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
...
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
from
sglang.srt.speculative.cpp_ngram.ngram_cache
import
NgramCache
from
sglang.srt.speculative.cpp_ngram.ngram_cache
import
NgramCache
from
sglang.srt.speculative.ngram_utils
import
NgramVerifyInput
from
sglang.srt.speculative.ngram_utils
import
NgramVerifyInput
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.utils
import
broadcast_pyobj
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/speculative/spec_info.py
View file @
73d4a5f8
from
abc
import
ABC
,
abstractmethod
from
enum
import
IntEnum
,
auto
from
enum
import
IntEnum
,
auto
from
typing
import
List
,
Tuple
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
class
SpeculativeAlgorithm
(
IntEnum
):
class
SpeculativeAlgorithm
(
IntEnum
):
...
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
...
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
if
name
is
not
None
:
if
name
is
not
None
:
name
=
name
.
upper
()
name
=
name
.
upper
()
return
name_map
[
name
]
return
name_map
[
name
]
class SpecInputType(IntEnum):
    """Tags the concrete kind of a :class:`SpecInput`.

    NOTE: introduced to distinguish the SpecInput types of multiple
    algorithms when asserting in attention backends. If all algorithms can
    share the same data structure for draft_input and verify_input, consider
    simplifying this.
    """

    EAGLE_DRAFT = auto()
    EAGLE_VERIFY = auto()
    NGRAM_VERIFY = auto()
class SpecInput(ABC):
    """Base class for speculative-decoding inputs carried in ``spec_info``.

    Each concrete input tags itself with a :class:`SpecInputType` so that
    attention backends can assert which phase (draft vs. verify) they are
    handling.
    """

    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # FIXME: remove this function which is only used for assertion
        # or use another variable name like `draft_input` to substitute `spec_info`
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        return self.spec_input_type in (
            SpecInputType.EAGLE_VERIFY,
            SpecInputType.NGRAM_VERIFY,
        )

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return the (token, logprob-token) multipliers for this input."""

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch: ModelWorkerBatch
    ) -> Tuple[List[int], List[int]]:
        """Scale the batch's global token counts by the spec coefficients."""
        token_coef, logprob_coef = self.get_spec_adjust_token_coefficient()
        scaled_tokens = [n * token_coef for n in forward_batch.global_num_tokens]
        scaled_logprob_tokens = [
            n * logprob_coef for n in forward_batch.global_num_tokens_for_logprob
        ]
        return scaled_tokens, scaled_logprob_tokens
python/sglang/srt/speculative/spec_utils.py
0 → 100644
View file @
73d4a5f8
from
__future__
import
annotations
import
logging
import
os
import
time
from
typing
import
TYPE_CHECKING
,
List
import
torch
import
triton
import
triton.language
as
tl
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.environ
import
envs
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.utils
import
is_cuda
,
is_hip
if
is_cuda
():
from
sgl_kernel
import
fast_topk
elif
is_hip
():
from
sgl_kernel
import
fast_topk
if
TYPE_CHECKING
:
from
sglang.srt.speculative.eagle_info
import
EagleVerifyInput
logger
=
logging
.
getLogger
(
__name__
)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN
=
envs
.
SGLANG_SIMULATE_ACC_LEN
.
get
()
# turn off if < 0
SIMULATE_ACC_METHOD
=
envs
.
SGLANG_SIMULATE_ACC_METHOD
.
get
()
TREE_TRAVERSE_TIME_THRESHOLD
=
1
# TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE
=
"tree_speculative_sampling_target_only"
in
globals
()
@triton.jit
def create_extend_after_decode_spec_info(
    verified_id,
    seq_lens,
    accept_lens,
    positions,
    new_verified_id,
    bs_upper: tl.constexpr,
):
    """Fill `positions` for the draft-extend pass and pick each request's
    last verified token into `new_verified_id`.

    One program per request (program id == request index in the batch).
    `bs_upper` must be >= batch size (power-of-2 bound for tl.arange).
    """
    pid = tl.program_id(axis=0)
    offsets = tl.arange(0, bs_upper)
    seq_length = tl.load(seq_lens + pid)
    accept_length = tl.load(accept_lens + pid)

    # Exclusive prefix sum of accept_lens locates this request's slice in the
    # flattened `positions` / `verified_id` arrays.
    accept_len_cumsum = tl.sum(
        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
    )
    positions_ptr = positions + accept_len_cumsum
    mask = offsets < accept_length
    # Positions of the accepted tokens: seq_len - accept_len .. seq_len - 1.
    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)

    # Index of this request's last accepted token in the flat verified_id.
    accept_len_cumsum += accept_length - 1
    verified_id_data = tl.load(verified_id + accept_len_cumsum)
    tl.store(new_verified_id + pid, verified_id_data)
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Scatter each request's newly allocated cache slots into req_to_token.

    One program per request: request `pid` copies its slice of the flat
    `out_cache_loc` array into
    req_to_token[req_pool_indices[pid], kv_start:kv_end], 32 elements at a
    time. `bs_upper` must be >= batch size (tl.arange bound).
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Exclusive prefix sum of (end - start) over earlier requests gives this
    # request's start offset into out_cache_loc.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    load_offset = tl.arange(0, BLOCK_SIZE)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = save_offset < kv_end
        data = tl.load(out_cache_ptr + load_offset, mask=mask)
        tl.store(token_pool + save_offset, data, mask=mask)
        save_offset += BLOCK_SIZE
        load_offset += BLOCK_SIZE
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
    extend_lens,
    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    """Write draft-token cache slots into req_to_token; for paged KV with
    topk > 1 additionally duplicate the last partial page per top-k branch
    and re-pack out_cache_loc.

    One program per request. When page_size == 1 or topk == 1 each request
    owns a fixed topk * speculative_num_steps slice of out_cache_loc;
    otherwise the per-request copy length comes from `extend_lens`.
    """
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    if page_size == 1 or topk == 1:
        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
        # Variable-length layout: locate this request's slice via an
        # exclusive prefix sum of extend_lens.
        bs_offset = tl.arange(0, bs_upper)
        copy_len = tl.load(extend_lens + pid)
        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
        out_cache_ptr = out_cache_loc + cum_copy_len

    # Part 1: Copy from out_cache_loc to req_to_token
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)

    if page_size == 1 or topk == 1:
        return

    # Part 2: Copy the indices for the last partial page
    # Each top-k branch gets its own copy of the shared partial page so the
    # branches can grow independently.
    prefix_len = tl.load(seq_lens + pid)
    last_page_len = prefix_len % page_size
    offsets = tl.arange(0, page_size)
    mask = offsets < last_page_len
    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
    prefix_base = token_pool + prefix_len - last_page_len

    for topk_id in range(topk):
        value = tl.load(prefix_base + offsets, mask=mask)
        tl.store(
            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
            value,
            mask=mask,
        )

    # Part 3: Remove the padding in out_cache_loc
    # Re-pack each branch's speculative slots contiguously.
    iter_offest = tl.arange(0, iter_upper)
    for topk_id in range(topk):
        indices = tl.load(
            prefix_base
            + topk_id * num_new_pages_per_topk_ * page_size
            + last_page_len
            + iter_offest,
            mask=iter_offest < speculative_num_steps,
        )
        tl.store(
            out_cache_loc
            + pid * topk * speculative_num_steps
            + topk_id * speculative_num_steps
            + iter_offest,
            indices,
            mask=iter_offest < speculative_num_steps,
        )
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Build per-draft-step KV indices / indptr for the draft decode passes.

    Grid: (speculative steps, batch size, topk). Program (iters, bid,
    topk_id) writes request `bid` / branch `topk_id`'s KV locations for draft
    step `iters` into kv_indices (rows separated by kv_indices_stride) and
    one kv_indptr entry per (request, branch).
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1  # at step k, k draft tokens per branch are already cached

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)  # exclusive prefix sum over earlier requests

    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the already-committed prefix KV locations in blocks of 128.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Append the locations of this branch's draft tokens produced so far.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Paged layout: branch slots start after the branch's duplicated
        # partial page (see assign_draft_cache_locs).
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    zid = bid * topk + topk_id
    if zid == 0:
        # Entry 0 of kv_indptr stays 0; the (bid=0, topk_id=0) program writes
        # the final (total) entry instead.
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Clear evict-mask bits so evictions stay page-aligned.

    One program per request over the row
    evict_mask[bid * num_draft_tokens : (bid + 1) * num_draft_tokens].
    With `num_false` draft slots kept, the kept region may end mid-page; the
    final loop forces the mask to False for the draft slots sharing that last
    page so the whole page survives. BLOCK_SIZE must be >= num_draft_tokens.
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # First draft-slot index inside the page containing the end of the kept
    # region (can be negative when that page lies wholly in the prefix).
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split each request's verify-token cache slots into kept and freed sets.

    One program per request. The first accept_length[bid] + 1 slots of the
    request's out_cache_loc row are packed into `tgt_cache_loc` (kept); the
    last to_free_num_slots[bid] slots are packed into `to_free_slots`
    (freed). Both outputs are flat, indexed via prefix sums over the batch.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
@
torch
.
compile
(
dynamic
=
True
)
def
get_src_tgt_cache_loc
(
seq_lens
:
torch
.
Tensor
,
out_cache_loc
:
torch
.
Tensor
,
accept_index
:
torch
.
Tensor
,
accept_length
:
torch
.
Tensor
,
draft_token_num
:
int
,
page_size
:
int
,
):
src_cache_loc
=
out_cache_loc
[
accept_index
]
tgt_cache_loc
=
torch
.
empty_like
(
src_cache_loc
)
extended_len
=
seq_lens
+
draft_token_num
keep_len
=
torch
.
minimum
(
(
seq_lens
+
accept_length
+
1
+
page_size
-
1
)
//
page_size
*
page_size
,
extended_len
,
)
to_free_num_slots
=
extended_len
-
keep_len
return
src_cache_loc
,
tgt_cache_loc
,
to_free_num_slots
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Compact tgt_cache_loc into out_cache_loc, dropping finished requests.

    One program per request. `accept_length_filter` is accept_length + 1 for
    unfinished requests and 0 for finished ones (see
    create_accept_length_filter); prefix sums over the two arrays give the
    old (padded) and new (compacted) start offsets for the copy.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len)
    tl.store(out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len)
@
torch
.
compile
(
dynamic
=
True
)
def
create_accept_length_filter
(
accept_length
:
torch
.
Tensor
,
unfinished_index_device
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
):
accept_length_filter
=
torch
.
zeros_like
(
accept_length
)
accept_length_filter
[
unfinished_index_device
]
=
(
accept_length
[
unfinished_index_device
]
+
1
)
seq_lens
.
add_
(
accept_length
+
1
)
return
accept_length_filter
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Select the next top-k candidate tokens while expanding the draft tree.

    i == 0 (first step after extend): every request contributes `topk`
    candidates straight from (topk_p, topk_index); hidden states are repeated
    once per candidate. i > 0: parent path scores are multiplied by the child
    probabilities and the global per-request top-k is kept via `fast_topk`,
    then the matching input ids and hidden states are gathered.

    Returns (input_ids, hidden_states, scores, tree_info); tree_info is the
    (score tensor, token ids, parent-index) triple used to build the tree.
    NOTE(review): device is hard-coded to "cuda" for the arange helpers, so
    this path requires CUDA tensors.
    """
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)

        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            torch.arange(-1, topk, dtype=torch.long, device="cuda")
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # Map each selected child back to its parent's row of hidden_states.
            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device="cuda"
            ).repeat_interleave(topk)
            hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Overwrite verification results with a simulated acceptance length
    (benchmarking only; enabled when SIMULATE_ACC_LEN > 0).

    Mutates `accept_length` (filled with simulate_acc_len - 1) and `predict`
    (filled with a dummy token id) in place, and returns a simulated
    accept_index tensor of shape (bs, spec_steps + 1) on CUDA.

    Raises:
        ValueError: if `simulate_acc_method` is not "multinomial" or
            "match-expected".
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and spec_steps + 1
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Bug fix: report the method actually passed by the caller, not the
        # module-level default SIMULATE_ACC_METHOD.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Traverse the tree constructed by the draft model to generate the logits mask.

    Walks one request's draft-token tree depth-first, feeding each
    grammar-accepted draft token into `grammar` on the way down and rolling it
    back on the way up, and filling `allocate_token_bitmask` with the vocab
    bitmask for every accepted tree position.

    Args:
        retrieve_next_token: per tree position, index of its first child, or -1
            if it has none. Same shape as `draft_tokens` (presumably 1-D over
            the positions of a single request's tree — confirm against caller).
        retrieve_next_sibling: per tree position, index of its next sibling, or
            -1 if it has none.
        draft_tokens: token id at each tree position; position 0 is the token
            produced by the target model, which is always treated as accepted.
        grammar: stateful grammar object. It is mutated during traversal
            (accept_token/rollback), but every accept is paired with a rollback,
            so it is restored to its entry state before this function returns.
        allocate_token_bitmask: output bitmask; zeroed here, then filled via
            grammar.fill_vocab_mask. 32 boolean vocab bits are packed into each
            32-bit integer element.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def dfs(
        curr: int,
        retrieve_next_token: torch.Tensor,
        retrieve_next_sibling: torch.Tensor,
        parent_pos: int,
    ):
        # Recursive DFS: `curr` is the current tree position, `parent_pos` the
        # position whose bitmask decides whether `curr`'s token is accepted.
        if curr == 0:
            # the first token generated by the target model, and thus it is always
            # accepted from the previous iteration
            accepted = True
        else:
            parent_bitmask = allocate_token_bitmask[parent_pos]
            curr_token_id = draft_tokens[curr]
            # 32 boolean bitmask values are packed into 32-bit integers
            accepted = (
                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
            ) != 0

        if accepted:
            if curr != 0:
                # Accept the current token
                grammar.accept_token(draft_tokens[curr])
            if not grammar.is_terminated():
                # Generate the bitmask for the current token
                grammar.fill_vocab_mask(allocate_token_bitmask, curr)

                if retrieve_next_token[curr] != -1:
                    # Visit the child node
                    dfs(
                        retrieve_next_token[curr],
                        retrieve_next_token,
                        retrieve_next_sibling,
                        curr,
                    )

            if curr != 0:
                # Rollback the current token so the grammar state is correct
                # for this node's siblings (and for the caller).
                grammar.rollback(1)

        # Siblings share the same parent, so they are visited with the
        # unchanged parent_pos regardless of whether `curr` was accepted.
        if retrieve_next_sibling[curr] != -1:
            # Visit the sibling node
            dfs(
                retrieve_next_sibling[curr],
                retrieve_next_token,
                retrieve_next_sibling,
                parent_pos,
            )

    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """Build the structured-output logit bitmask for a batch of draft-token trees.

    A draft model's tokens may or may not be valid under a request's grammar, so
    for every request that carries one we run a DFS over its draft tree (see
    `traverse_tree`) to decide 1. which draft tokens the grammar accepts and
    2. the corresponding logit mask for each accepted position.

    Returns the CPU bitmask tensor covering all draft tokens of all requests,
    or None when no request in the batch has a grammar. Side effect:
    `verify_input.grammar` is set to the grammar of the last request that has
    one (None if none do).
    """
    tokens_per_req = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    bitmask = None
    last_grammar = None
    for idx, req in enumerate(reqs):
        if req.grammar is None:
            # Nothing to mask for grammar-free requests.
            continue

        if bitmask is None:
            # Allocate lazily, only once a grammar actually exists:
            # one bitmask row per draft token across the whole batch.
            bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        last_grammar = req.grammar

        start = time.perf_counter()
        row_slice = slice(idx * tokens_per_req, (idx + 1) * tokens_per_req)
        traverse_tree(
            retrieve_next_token_cpu[idx],
            retrieve_next_sibling_cpu[idx],
            draft_tokens_cpu[idx],
            req.grammar,
            bitmask[row_slice],
        )
        elapsed = time.perf_counter() - start
        if elapsed > TREE_TRAVERSE_TIME_THRESHOLD:
            # Slow traversal usually means a pathological grammar; surface it.
            logger.warning(
                f"Bit mask generation took {elapsed} seconds with "
                f"grammar: {req.grammar}"
            )

    verify_input.grammar = last_grammar
    return bitmask
python/sglang/srt/two_batch_overlap.py
View file @
73d4a5f8
...
@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
)
)
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_info
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
...
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
def
get_token_num_per_seq
(
def
get_token_num_per_seq
(
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
=
None
,
spec_info
:
Optional
[
Spec
Input
]
=
None
,
):
):
if
forward_mode
.
is_target_verify
():
if
forward_mode
.
is_target_verify
():
return
spec_info
.
draft_token_num
return
spec_info
.
draft_token_num
...
@@ -273,7 +274,7 @@ def compute_split_token_index(
...
@@ -273,7 +274,7 @@ def compute_split_token_index(
def
compute_split_indices_for_cuda_graph_replay
(
def
compute_split_indices_for_cuda_graph_replay
(
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
cuda_graph_num_tokens
:
int
,
cuda_graph_num_tokens
:
int
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
forward_mode_for_tbo_split
=
(
forward_mode_for_tbo_split
=
(
forward_mode
if
forward_mode
!=
ForwardMode
.
IDLE
else
ForwardMode
.
DECODE
forward_mode
if
forward_mode
!=
ForwardMode
.
IDLE
else
ForwardMode
.
DECODE
...
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
...
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
bs
:
int
,
bs
:
int
,
num_token_non_padded
:
int
,
num_token_non_padded
:
int
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
token_num_per_seq
=
get_token_num_per_seq
(
token_num_per_seq
=
get_token_num_per_seq
(
forward_mode
=
forward_mode
,
spec_info
=
spec_info
forward_mode
=
forward_mode
,
spec_info
=
spec_info
...
...
test/srt/test_forward_split_prefill.py
View file @
73d4a5f8
...
@@ -7,7 +7,6 @@ or
...
@@ -7,7 +7,6 @@ or
python3 test_forward_split_prefill.py
python3 test_forward_split_prefill.py
"""
"""
import
time
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
...
@@ -16,7 +15,7 @@ import torch
...
@@ -16,7 +15,7 @@ import torch
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment