Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f9905d59
Unverified
Commit
f9905d59
authored
Feb 07, 2025
by
Yineng Zhang
Committed by
GitHub
Feb 07, 2025
Browse files
support speculative decoding kernel in sgl-kernel (#3373)
Co-authored-by:
Ying Sheng
<
sqy1415@gmail.com
>
parent
45c87e08
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1299 additions
and
133 deletions
+1299
-133
python/sglang/srt/speculative/build_eagle_tree.py
python/sglang/srt/speculative/build_eagle_tree.py
+482
-102
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+67
-29
sgl-kernel/pyproject.toml
sgl-kernel/pyproject.toml
+1
-1
sgl-kernel/setup.py
sgl-kernel/setup.py
+2
-0
sgl-kernel/src/sgl-kernel/__init__.py
sgl-kernel/src/sgl-kernel/__init__.py
+6
-0
sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu
sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu
+209
-0
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu
+120
-0
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh
+184
-0
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
+16
-0
sgl-kernel/src/sgl-kernel/ops/__init__.py
sgl-kernel/src/sgl-kernel/ops/__init__.py
+84
-0
sgl-kernel/src/sgl-kernel/torch_extension.cc
sgl-kernel/src/sgl-kernel/torch_extension.cc
+23
-0
sgl-kernel/tests/test_speculative_sampling.py
sgl-kernel/tests/test_speculative_sampling.py
+104
-0
sgl-kernel/version.py
sgl-kernel/version.py
+1
-1
No files found.
python/sglang/srt/speculative/build_eagle_tree.py
View file @
f9905d59
import
cutex
# NOTE: Please run this file to make sure the test cases are correct.
from
typing
import
List
import
torch
# parent_table [bs,topk*depth+)]
# selected_index [bs,draft_token_num-1)]
# verified_seq_len [bs]
# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token]
# positions [bs*draft_token]
# retrive_index [b, draft_token, depth+2]
kernels
=
cutex
.
SourceModule
(
"""
//cuda
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
Tensor<bool, 1> tree_mask, Tensor<long, 1> positions, Tensor<long, 3> retrive_index, int topk, int depth, int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num){
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for(int i=0; i<bid; i++){
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = seq_tree_idx + (seq_len+draft_token_num)*tid + seq_len + 1;
for(int i=0; i<draft_token_num-1; i++){
tree_mask[token_tree_idx+i] = false;
}
int position = 0;
if (tid==0){
positions[bid*draft_token_num] = seq_len;
retrive_index[bid][0][0] = bid * draft_token_num;
return;
}
int depends_order[10];
int cur_position = tid-1;
while(true){
depends_order[position] = cur_position+1;
position += 1;
tree_mask[token_tree_idx+cur_position] = true;
int parent_tb_idx = selected_index[bid][cur_position]/topk;
if(parent_tb_idx==0){
break;
}
int token_idx = parent_list[bid][parent_tb_idx];
for(cur_position=0; cur_position<draft_token_num;cur_position++){
if(selected_index[bid][cur_position]==token_idx){
break;
}
}
}
positions[bid*draft_token_num+tid] = position + seq_len;
int is_leaf = 0;
for(int i=1;i<draft_token_num;i++){
if(tree_mask[seq_tree_idx + i * (draft_token_num+seq_len) + seq_len + tid])
{
is_leaf ++;
}
}
if(is_leaf==1){
for(int i=0; i<position; i++){
retrive_index[bid][tid][position-i] = depends_order[i] + bid * draft_token_num;
}
retrive_index[bid][tid][0] = bid*draft_token_num;
}
}
//!cuda
"""
,
float_bits
=
16
,
# change to 16 to use half precision as `float` type in the above source code.
boundscheck
=
True
,
# turning on for debug and off for performance (to use full threads of a block), default is on.
)
from
sglang.srt.utils
import
is_cuda_available
if
is_cuda_available
():
from
sgl_kernel
import
build_tree_kernel
as
sgl_build_tree_kernel
from
sgl_kernel
import
(
build_tree_kernel_efficient
as
sgl_build_tree_kernel_efficient
,
)
def
build_tree_kernel_efficient_preprocess
(
verified_id
:
torch
.
Tensor
,
score_list
:
List
[
torch
.
Tensor
],
token_list
:
List
[
torch
.
Tensor
],
parents_list
:
List
[
torch
.
Tensor
],
num_verify_tokens
:
int
,
):
score_list
=
torch
.
cat
(
score_list
,
dim
=
1
).
flatten
(
1
)
# b, n, topk; n= 1 + (num_steps-1) * self.topk
ss_token_list
=
torch
.
cat
(
token_list
,
dim
=
1
)
# b, (self.topk + (num_steps-1) * self.topk)
top_scores
=
torch
.
topk
(
score_list
,
num_verify_tokens
-
1
,
dim
=-
1
)
top_scores_index
=
top_scores
.
indices
top_scores_index
=
torch
.
sort
(
top_scores_index
).
values
draft_tokens
=
torch
.
gather
(
ss_token_list
,
index
=
top_scores_index
,
dim
=
1
)
draft_tokens
=
torch
.
cat
((
verified_id
.
unsqueeze
(
1
),
draft_tokens
),
dim
=
1
).
flatten
()
parent_list
=
torch
.
cat
(
parents_list
[:
-
1
],
dim
=
1
)
return
parent_list
,
top_scores_index
,
draft_tokens
def
build_tree_kernel_efficient
(
verified_id
:
torch
.
Tensor
,
score_list
:
List
[
torch
.
Tensor
],
token_list
:
List
[
torch
.
Tensor
],
parents_list
:
List
[
torch
.
Tensor
],
seq_lens
:
torch
.
Tensor
,
seq_lens_sum
:
int
,
topk
:
int
,
spec_steps
:
int
,
num_verify_tokens
:
int
,
):
parent_list
,
top_scores_index
,
draft_tokens
=
(
build_tree_kernel_efficient_preprocess
(
verified_id
,
score_list
,
token_list
,
parents_list
,
num_verify_tokens
,
)
)
# seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
bs
=
seq_lens
.
numel
()
device
=
seq_lens
.
device
# e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
# where each row indicates the attending pattern of each draft token
# TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
tree_mask
=
torch
.
full
(
(
seq_lens_sum
*
num_verify_tokens
+
num_verify_tokens
*
num_verify_tokens
*
bs
,
),
True
,
device
=
device
,
)
retrive_index
=
torch
.
full
(
(
bs
,
num_verify_tokens
),
-
1
,
device
=
device
,
dtype
=
torch
.
long
)
retrive_next_token
=
torch
.
full
(
(
bs
,
num_verify_tokens
),
-
1
,
device
=
device
,
dtype
=
torch
.
long
)
retrive_next_sibling
=
torch
.
full
(
(
bs
,
num_verify_tokens
),
-
1
,
device
=
device
,
dtype
=
torch
.
long
)
# position: where each token belongs to
# e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
# then, positions = [7, 8, 8, 9]
positions
=
torch
.
empty
((
bs
*
num_verify_tokens
,),
device
=
device
,
dtype
=
torch
.
long
)
sgl_build_tree_kernel_efficient
(
parent_list
,
top_scores_index
,
seq_lens
.
to
(
torch
.
int32
),
tree_mask
,
positions
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
topk
,
spec_steps
,
num_verify_tokens
,
)
return
(
tree_mask
,
positions
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
draft_tokens
,
)
def
build_tree_kernel
(
parent_list
,
top_score_index
,
seq_lens
,
seq_lens_sum
,
topk
,
depth
,
draft_token
verified_id
:
torch
.
Tensor
,
score_list
:
List
[
torch
.
Tensor
],
token_list
:
List
[
torch
.
Tensor
],
parents_list
:
List
[
torch
.
Tensor
],
seq_lens
:
torch
.
Tensor
,
seq_lens_sum
:
int
,
topk
:
int
,
spec_steps
:
int
,
num_verify_tokens
:
int
,
):
parent_list
,
top_scores_index
,
draft_tokens
=
(
build_tree_kernel_efficient_preprocess
(
verified_id
,
score_list
,
token_list
,
parents_list
,
num_verify_tokens
,
)
)
bs
=
seq_lens
.
numel
()
device
=
parent_list
.
device
device
=
seq_lens
.
device
tree_mask
=
torch
.
full
(
(
seq_lens_sum
*
draft_token
+
draft_token
*
draft_token
*
bs
,),
(
seq_lens_sum
*
num_verify_tokens
+
num_verify_tokens
*
num_verify_tokens
*
bs
,
),
True
,
device
=
device
,
)
retrive_index
=
torch
.
full
(
(
bs
,
draft
_token
,
depth
+
2
),
-
1
,
device
=
device
,
dtype
=
torch
.
long
(
bs
,
num_verify
_token
s
,
spec_steps
+
2
),
-
1
,
device
=
device
,
dtype
=
torch
.
long
)
positions
=
torch
.
empty
((
bs
*
draft
_token
,),
device
=
device
,
dtype
=
torch
.
long
)
positions
=
torch
.
empty
((
bs
*
num_verify
_token
s
,),
device
=
device
,
dtype
=
torch
.
long
)
kernels
.
build_tree
(
sgl_
build_tree
_kernel
(
parent_list
,
top_score_index
,
top_score
s
_index
,
seq_lens
.
to
(
torch
.
int32
),
tree_mask
,
positions
,
retrive_index
,
topk
,
depth
,
draft_token
,
grid
=
(
bs
,
1
,
1
),
block
=
(
64
,
1
,
1
),
spec_steps
,
num_verify_tokens
,
)
index
=
retrive_index
.
sum
(
dim
=-
1
)
!=
-
depth
-
2
index
=
retrive_index
.
sum
(
dim
=-
1
)
!=
-
spec_steps
-
2
cum_len
=
torch
.
cumsum
(
torch
.
sum
(
index
,
dim
=-
1
),
dim
=-
1
)
retrive_cum_len
=
torch
.
zeros
(
(
cum_len
.
numel
()
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
retrive_cum_len
[
1
:]
=
cum_len
# TODO: this indexing cause a synchronization, optimize this
retrive_index
=
retrive_index
[
index
]
return
tree_mask
,
positions
,
retrive_index
,
retrive_cum_len
return
tree_mask
,
positions
,
retrive_index
,
retrive_cum_len
,
draft_tokens
if
__name__
==
"__main__"
:
def
test_build_tree_kernel
():
def
findp
(
p_i
,
index
,
parent_list
):
pos
=
index
//
10
index_list
=
index
.
tolist
()
...
...
@@ -311,21 +362,21 @@ if __name__ == "__main__":
bs
=
verified_seq_len
.
shape
[
0
]
topk
=
10
depth
=
5
# depth <= 10
draft_token
=
64
num_
draft_token
=
64
tree_mask
=
torch
.
full
(
(
torch
.
sum
(
verified_seq_len
).
item
()
*
draft_token
+
draft_token
*
draft_token
*
bs
,
torch
.
sum
(
verified_seq_len
).
item
()
*
num_
draft_token
+
num_
draft_token
*
num_
draft_token
*
bs
,
),
True
,
).
cuda
()
retrive_index
=
torch
.
full
(
(
bs
,
draft_token
,
depth
+
2
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
long
(
bs
,
num_
draft_token
,
depth
+
2
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
positions
=
torch
.
empty
((
bs
*
draft_token
,),
device
=
"cuda"
,
dtype
=
torch
.
long
)
positions
=
torch
.
empty
((
bs
*
num_
draft_token
,),
device
=
"cuda"
,
dtype
=
torch
.
long
)
kernels
.
build_tree
(
sgl_
build_tree
_kernel
(
parent_list
.
unsqueeze
(
0
),
index
.
unsqueeze
(
0
),
verified_seq_len
,
...
...
@@ -334,16 +385,345 @@ if __name__ == "__main__":
retrive_index
,
topk
,
depth
,
draft_token
,
grid
=
(
bs
,
1
,
1
),
block
=
(
64
,
1
,
1
),
num_draft_token
,
)
retrive_index
=
retrive_index
[
retrive_index
.
sum
(
dim
=-
1
)
!=
-
depth
-
2
]
c_mask
,
c_positions
,
c_retive_index
=
create_mask
(
verified_seq_len
,
draft_token
,
index
,
parent_list
,
depth
verified_seq_len
,
num_
draft_token
,
index
,
parent_list
,
depth
)
assert
torch
.
allclose
(
tree_mask
,
c_mask
),
"tree mask has error."
assert
torch
.
allclose
(
positions
,
c_positions
),
"positions has error."
assert
torch
.
allclose
(
retrive_index
,
c_retive_index
),
"retrive_index has error."
def
test_build_tree_kernel_efficient
():
verified_id
=
torch
.
tensor
([
29974
,
13
],
device
=
"cuda"
,
dtype
=
torch
.
int32
)
score_list
=
[
torch
.
tensor
(
[
[[
7.1127e-01
,
2.8292e-01
,
2.2995e-03
,
1.7357e-03
]],
[[
9.7476e-01
,
2.2219e-02
,
6.5031e-04
,
1.3212e-04
]],
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
),
torch
.
tensor
(
[
[
[
6.9142e-01
,
1.2863e-02
,
1.6873e-03
,
1.1871e-03
],
[
2.4787e-01
,
1.8818e-02
,
1.4204e-02
,
9.2235e-04
],
[
2.2971e-03
,
1.6700e-06
,
1.8737e-07
,
8.3146e-08
],
[
1.2771e-03
,
2.4374e-04
,
1.7832e-04
,
1.1947e-05
],
],
[
[
8.4832e-02
,
6.6068e-02
,
5.8304e-02
,
5.7851e-02
],
[
2.3616e-03
,
1.1243e-03
,
5.4368e-04
,
2.7768e-04
],
[
2.5286e-04
,
1.5578e-04
,
2.8817e-05
,
1.2888e-05
],
[
1.2834e-04
,
2.5417e-06
,
1.1279e-06
,
1.6088e-08
],
],
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
),
torch
.
tensor
(
[
[
[
6.6438e-01
,
2.6997e-02
,
2.4236e-05
,
4.0821e-06
],
[
2.4402e-01
,
2.8409e-03
,
5.0935e-04
,
2.9022e-04
],
[
1.6178e-02
,
2.0567e-03
,
4.5892e-04
,
3.0034e-05
],
[
1.3023e-02
,
5.0497e-04
,
3.6371e-04
,
8.7750e-05
],
],
[
[
2.3263e-02
,
2.0054e-02
,
9.3990e-03
,
2.7783e-03
],
[
6.4156e-02
,
5.5506e-04
,
1.0429e-04
,
9.7211e-05
],
[
4.9950e-02
,
5.0630e-03
,
9.0068e-04
,
3.3656e-04
],
[
7.5817e-03
,
8.5731e-04
,
6.9972e-04
,
6.0793e-04
],
],
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
),
torch
.
tensor
(
[
[
[
6.6420e-01
,
1.0525e-04
,
6.5864e-05
,
1.2253e-06
],
[
1.3019e-01
,
1.0461e-01
,
5.2083e-03
,
1.6777e-03
],
[
2.0103e-02
,
6.7335e-03
,
1.2625e-04
,
1.0364e-05
],
[
1.5142e-02
,
7.0819e-04
,
9.6595e-05
,
8.7951e-05
],
],
[
[
5.8608e-02
,
1.8840e-03
,
7.8535e-04
,
4.4400e-04
],
[
1.2185e-02
,
2.0684e-03
,
1.7418e-03
,
1.4327e-03
],
[
6.2455e-03
,
6.1487e-03
,
2.6862e-03
,
1.8034e-03
],
[
1.8590e-03
,
1.6151e-03
,
1.2481e-03
,
3.6038e-04
],
],
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
),
]
token_list
=
[
torch
.
tensor
(
[[
29896
,
29906
,
29900
,
29945
],
[
13
,
2
,
29871
,
28956
]],
dtype
=
torch
.
int64
,
device
=
"cuda"
,
),
torch
.
tensor
(
[
[
29889
,
29974
,
29945
,
29900
,
29974
,
29922
,
29930
,
29958
,
29889
,
29974
,
29930
,
29945
,
29974
,
29922
,
29930
,
29958
,
],
[
22550
,
4136
,
16492
,
8439
,
29871
,
2
,
3001
,
13
,
2
,
13
,
29906
,
29946
,
2
,
13
,
29871
,
259
,
],
],
device
=
"cuda"
,
),
torch
.
tensor
(
[
[
29946
,
29945
,
29953
,
29906
,
29896
,
29945
,
29900
,
29906
,
29896
,
29945
,
29906
,
29953
,
29896
,
29945
,
29906
,
29946
,
],
[
29871
,
2
,
29901
,
29889
,
29871
,
2
,
395
,
259
,
29901
,
29871
,
2
,
29889
,
3001
,
1234
,
7146
,
2186
,
],
],
device
=
"cuda"
,
),
torch
.
tensor
(
[
[
29946
,
29974
,
29945
,
29930
,
29889
,
29922
,
29974
,
29930
,
29974
,
29946
,
29930
,
29922
,
29889
,
29974
,
29945
,
29922
,
],
[
29941
,
29906
,
2
,
29946
,
29871
,
450
,
319
,
14990
,
29946
,
29941
,
2
,
29906
,
29871
,
2
,
3001
,
13
,
],
],
device
=
"cuda"
,
),
]
parents_list
=
[
torch
.
tensor
(
[[
-
1
,
0
,
1
,
2
,
3
],
[
-
1
,
0
,
1
,
2
,
3
]],
dtype
=
torch
.
int64
,
device
=
"cuda"
),
torch
.
tensor
([[
4
,
8
,
9
,
10
],
[
4
,
5
,
6
,
7
]],
dtype
=
torch
.
int64
,
device
=
"cuda"
),
torch
.
tensor
(
[[
20
,
24
,
21
,
28
],
[
24
,
28
,
20
,
21
]],
dtype
=
torch
.
int64
,
device
=
"cuda"
),
torch
.
tensor
(
[[
36
,
40
,
41
,
44
],
[
36
,
40
,
44
,
45
]],
dtype
=
torch
.
int64
,
device
=
"cuda"
),
]
seq_lens
=
torch
.
tensor
([
5
,
10
],
dtype
=
torch
.
int64
,
device
=
"cuda"
)
topk
=
4
depth
=
4
num_draft_token
=
8
tree_mask
,
position
,
retrive_index
,
retrive_cum_len
,
draft_tokens
=
(
build_tree_kernel
(
verified_id
=
verified_id
,
score_list
=
score_list
,
token_list
=
token_list
,
parents_list
=
parents_list
,
seq_lens
=
seq_lens
,
seq_lens_sum
=
torch
.
sum
(
seq_lens
).
item
(),
topk
=
topk
,
spec_steps
=
depth
,
num_verify_tokens
=
num_draft_token
,
)
)
from
sglang.srt.utils
import
first_rank_print
first_rank_print
(
"=========== build tree kernel =========="
)
# first_rank_print(f"{tree_mask=}", flush=True)
first_rank_print
(
f
"
{
position
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
retrive_index
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
retrive_cum_len
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
draft_tokens
=
}
"
,
flush
=
True
)
assert
position
.
tolist
()
==
[
5
,
6
,
6
,
7
,
7
,
8
,
8
,
9
,
10
,
11
,
12
,
12
,
12
,
12
,
13
,
14
]
assert
retrive_index
.
tolist
()
==
[
[
0
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
0
,
2
,
4
,
6
,
-
1
,
-
1
],
[
0
,
1
,
3
,
5
,
7
,
-
1
],
[
8
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
8
,
9
,
10
,
-
1
,
-
1
,
-
1
],
[
8
,
9
,
12
,
-
1
,
-
1
,
-
1
],
[
8
,
9
,
13
,
-
1
,
-
1
,
-
1
],
[
8
,
9
,
11
,
14
,
15
,
-
1
],
]
assert
retrive_cum_len
.
tolist
()
==
[
0
,
3
,
8
]
assert
draft_tokens
.
tolist
()
==
[
29974
,
29896
,
29906
,
29889
,
29974
,
29946
,
29896
,
29946
,
13
,
13
,
22550
,
4136
,
16492
,
8439
,
29871
,
29941
,
]
(
tree_mask
,
position
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
draft_tokens
,
)
=
build_tree_kernel_efficient
(
verified_id
=
verified_id
,
score_list
=
score_list
,
token_list
=
token_list
,
parents_list
=
parents_list
,
seq_lens
=
seq_lens
,
seq_lens_sum
=
torch
.
sum
(
seq_lens
).
item
(),
topk
=
topk
,
spec_steps
=
depth
,
num_verify_tokens
=
num_draft_token
,
)
first_rank_print
(
"=========== build tree kernel efficient =========="
)
# first_rank_print(f"{tree_mask=}", flush=True)
first_rank_print
(
f
"
{
position
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
retrive_index
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
retrive_next_token
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
retrive_next_sibling
=
}
"
,
flush
=
True
)
first_rank_print
(
f
"
{
draft_tokens
=
}
"
,
flush
=
True
)
assert
position
.
tolist
()
==
[
5
,
6
,
6
,
7
,
7
,
8
,
8
,
9
,
10
,
11
,
12
,
12
,
12
,
12
,
13
,
14
]
assert
retrive_index
.
tolist
()
==
[
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
],
[
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
],
]
assert
retrive_next_token
.
tolist
()
==
[
[
1
,
3
,
4
,
5
,
6
,
7
,
-
1
,
-
1
],
[
1
,
2
,
-
1
,
6
,
-
1
,
-
1
,
7
,
-
1
],
]
assert
retrive_next_sibling
.
tolist
()
==
[
[
-
1
,
2
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
3
,
4
,
5
,
-
1
,
-
1
,
-
1
],
]
assert
draft_tokens
.
tolist
()
==
[
29974
,
29896
,
29906
,
29889
,
29974
,
29946
,
29896
,
29946
,
13
,
13
,
22550
,
4136
,
16492
,
8439
,
29871
,
29941
,
]
if
__name__
==
"__main__"
:
test_build_tree_kernel_efficient
()
test_build_tree_kernel
()
python/sglang/srt/speculative/eagle_utils.py
View file @
f9905d59
...
...
@@ -258,39 +258,77 @@ class EagleVerifyInput:
return
kv_indices
,
cum_kv_seq_len
,
qo_indptr
,
self
.
custom_mask
def
verify
(
self
,
batch
:
ScheduleBatch
,
logits_output
:
torch
.
Tensor
)
->
torch
.
Tensor
:
predict
=
torch
.
argmax
(
logits_output
.
next_token_logits
,
dim
=-
1
)
predict
=
torch
.
cat
(
[
predict
,
torch
.
full
([
1
],
-
1
,
dtype
=
torch
.
long
,
device
=
"cuda"
)],
dim
=-
1
)
draft_token
=
torch
.
cat
(
[
self
.
draft_token
,
torch
.
full
([
1
],
-
1
,
dtype
=
torch
.
long
,
device
=
"cuda"
)],
[
self
.
draft_token
,
torch
.
full
([
1
],
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)],
dim
=-
1
,
)
target_predict
=
predict
[
self
.
retrive_index
]
candidates
=
draft_token
[
self
.
retrive_index
]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask
=
candidates
[:,
1
:]
==
target_predict
[:,
:
-
1
]
accept_mask
=
(
torch
.
cumprod
(
accept_mask
,
dim
=
1
)).
sum
(
dim
=
1
)
bs
=
self
.
retrive_cum_len
.
numel
()
-
1
max_draft_len
=
self
.
retrive_index
.
shape
[
-
1
]
accept_index
=
torch
.
full
(
(
bs
,
max_draft_len
),
-
1
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
accept_length
=
torch
.
empty
((
bs
,),
dtype
=
torch
.
int
,
device
=
"cuda"
)
extract_index
=
torch
.
full
((
bs
*
2
,),
0
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
eagle_verify_retrive
[(
bs
,)](
self
.
retrive_index
.
contiguous
(),
accept_mask
.
contiguous
(),
self
.
retrive_cum_len
,
accept_index
,
accept_length
,
extract_index
,
max_draft_len
,
self
.
draft_token_num
,
triton
.
next_power_of_2
(
max_draft_len
),
)
if
batch
.
sampling_info
.
is_all_greedy
:
# temp == 0
bs
=
self
.
retrive_cum_len
.
numel
()
-
1
predict
=
torch
.
argmax
(
logits_output
.
next_token_logits
,
dim
=-
1
)
predict
=
torch
.
cat
(
[
predict
,
torch
.
full
([
1
],
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)],
dim
=-
1
)
target_predict
=
predict
[
self
.
retrive_index
]
# logits = logits_output.next_token_logits[self.retrive_index]
# target_predict = torch.argmax(logits[:, :-1], dim=-1)
accept_mask
=
candidates
[:,
1
:]
==
target_predict
[:,
:
-
1
]
accept_mask
=
(
torch
.
cumprod
(
accept_mask
,
dim
=
1
)).
sum
(
dim
=
1
)
max_draft_len
=
self
.
retrive_index
.
shape
[
-
1
]
accept_index
=
torch
.
full
(
(
bs
,
max_draft_len
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
accept_length
=
torch
.
empty
((
bs
,),
dtype
=
torch
.
int
,
device
=
"cuda"
)
extract_index
=
torch
.
full
((
bs
*
2
,),
0
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
eagle_verify_retrive
[(
bs
,)](
self
.
retrive_index
.
contiguous
(),
accept_mask
.
contiguous
(),
self
.
retrive_cum_len
,
accept_index
,
accept_length
,
extract_index
,
max_draft_len
,
self
.
draft_token_num
,
triton
.
next_power_of_2
(
max_draft_len
),
)
else
:
# temp > 0
bs
=
self
.
retrive_index
.
shape
[
0
]
predict_shape
=
list
(
logits_output
.
next_token_logits
.
shape
)[:
-
1
]
predict_shape
[
-
1
]
+=
1
target_logits
=
logits_output
.
next_token_logits
[
self
.
retrive_index
]
predict
=
torch
.
full
(
predict_shape
,
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
accept_index
=
torch
.
full
(
(
bs
,
self
.
spec_steps
+
1
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
accept_length
=
torch
.
empty
((
bs
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
expanded_temperature
=
batch
.
sampling_info
.
temperatures
.
unsqueeze
(
1
)
target_probs
=
F
.
softmax
(
target_logits
/
expanded_temperature
,
dim
=-
1
)
draft_probs
=
torch
.
full_like
(
target_probs
,
0
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
coins
=
torch
.
rand_like
(
candidates
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
tree_speculative_sampling_target_only
(
predicts
=
predict
,
# mutable
accept_index
=
accept_index
,
# mutable
accept_token_num
=
accept_length
,
# mutable
candidates
=
candidates
.
to
(
torch
.
int32
),
retrive_index
=
self
.
retrive_index
.
to
(
torch
.
int32
),
retrive_next_token
=
self
.
retrive_next_token
.
to
(
torch
.
int32
),
retrive_next_sibling
=
self
.
retrive_next_sibling
.
to
(
torch
.
int32
),
uniform_samples
=
coins
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
global_server_args_dict
[
"speculative_accept_threshold_single"
],
threshold_acc
=
global_server_args_dict
[
"speculative_accept_threshold_acc"
],
deterministic
=
True
,
)
new_accept_index
=
[]
unfinished_index
=
[]
...
...
sgl-kernel/pyproject.toml
View file @
f9905d59
...
...
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name
=
"sgl-kernel"
version
=
"0.0.3.post
1
"
version
=
"0.0.3.post
2
"
description
=
"Kernel Library for SGLang"
readme
=
"README.md"
requires-python
=
">=3.9"
...
...
sgl-kernel/setup.py
View file @
f9905d59
...
...
@@ -99,6 +99,8 @@ sources = [
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu"
,
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu"
,
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu"
,
"src/sgl-kernel/csrc/eagle_utils.cu"
,
"src/sgl-kernel/csrc/speculative_sampling.cu"
,
"3rdparty/flashinfer/csrc/activation.cu"
,
"3rdparty/flashinfer/csrc/bmm_fp8.cu"
,
"3rdparty/flashinfer/csrc/norm.cu"
,
...
...
sgl-kernel/src/sgl-kernel/__init__.py
View file @
f9905d59
...
...
@@ -10,6 +10,8 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
from
sgl_kernel.ops
import
(
apply_rope_with_cos_sin_cache_inplace
,
bmm_fp8
,
build_tree_kernel
,
build_tree_kernel_efficient
,
custom_dispose
,
custom_reduce
,
fp8_scaled_mm
,
...
...
@@ -31,6 +33,7 @@ from sgl_kernel.ops import (
top_k_renorm_prob
,
top_k_top_p_sampling_from_probs
,
top_p_renorm_prob
,
tree_speculative_sampling_target_only
,
)
__all__
=
[
...
...
@@ -57,4 +60,7 @@ __all__ = [
"top_k_renorm_prob"
,
"top_k_top_p_sampling_from_probs"
,
"top_p_renorm_prob"
,
"tree_speculative_sampling_target_only"
,
"build_tree_kernel_efficient"
,
"build_tree_kernel"
,
]
sgl-kernel/src/sgl-kernel/csrc/eagle_utils.cu
0 → 100644
View file @
f9905d59
/*
* Copyright (c) 2025 by SGLang team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__
void
build_tree_efficient
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int64_t
*
retrive_next_token
,
int64_t
*
retrive_next_sibling
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
int
bid
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
if
(
tid
>=
draft_token_num
)
{
return
;
}
int
seq_tree_idx
=
draft_token_num
*
draft_token_num
*
bid
;
for
(
int
i
=
0
;
i
<
bid
;
i
++
)
{
seq_tree_idx
+=
verified_seq_len
[
i
]
*
draft_token_num
;
}
int
seq_len
=
verified_seq_len
[
bid
];
int
token_tree_idx
=
seq_tree_idx
+
(
seq_len
+
draft_token_num
)
*
tid
+
seq_len
+
1
;
for
(
int
i
=
0
;
i
<
draft_token_num
-
1
;
i
++
)
{
tree_mask
[
token_tree_idx
+
i
]
=
false
;
}
int
position
=
0
;
if
(
tid
==
0
)
{
positions
[
bid
*
draft_token_num
]
=
seq_len
;
int
retrive_index_offset
=
bid
*
draft_token_num
;
for
(
int
i
=
draft_token_num
-
1
;
i
>
0
;
--
i
)
{
int
current_token_idx
=
retrive_index_offset
+
i
;
retrive_index
[
bid
*
draft_token_num
+
i
]
=
current_token_idx
;
int
parent_tb_idx
=
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
i
-
1
]
/
topk
;
int
parent_position
=
0
;
if
(
parent_tb_idx
>
0
)
{
int
parent_token_idx
=
parent_list
[
bid
*
(
topk
*
(
depth
-
1
)
+
1
)
+
parent_tb_idx
];
for
(;
parent_position
<
draft_token_num
;
++
parent_position
)
{
if
(
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
parent_position
]
==
parent_token_idx
)
{
++
parent_position
;
break
;
}
}
}
if
(
parent_position
==
draft_token_num
)
{
printf
(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped."
);
continue
;
}
if
(
retrive_next_token
[
bid
*
draft_token_num
+
parent_position
]
==
-
1
)
{
retrive_next_token
[
bid
*
draft_token_num
+
parent_position
]
=
i
;
}
else
{
int
origin_next_token
=
retrive_next_token
[
bid
*
draft_token_num
+
parent_position
];
retrive_next_token
[
bid
*
draft_token_num
+
parent_position
]
=
i
;
retrive_next_sibling
[
bid
*
draft_token_num
+
i
]
=
origin_next_token
;
}
}
retrive_index
[
bid
*
draft_token_num
]
=
bid
*
draft_token_num
;
}
else
{
int
cur_position
=
tid
-
1
;
while
(
true
)
{
position
+=
1
;
tree_mask
[
token_tree_idx
+
cur_position
]
=
true
;
int
parent_tb_idx
=
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
cur_position
]
/
topk
;
if
(
parent_tb_idx
==
0
)
{
break
;
}
int
token_idx
=
parent_list
[
bid
*
(
topk
*
(
depth
-
1
)
+
1
)
+
parent_tb_idx
];
for
(
cur_position
=
0
;
cur_position
<
draft_token_num
;
++
cur_position
)
{
if
(
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
cur_position
]
==
token_idx
)
{
break
;
}
}
}
positions
[
bid
*
draft_token_num
+
tid
]
=
position
+
seq_len
;
}
}
void
build_tree_kernel_efficient
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
at
::
Tensor
retrive_next_token
,
at
::
Tensor
retrive_next_sibling
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
// TODO (ying) check shape
// TODO (ying) check type
int
bs
=
parent_list
.
size
(
0
);
dim3
grid
(
bs
);
dim3
block
(
draft_token_num
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
build_tree_efficient
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_next_token
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_next_sibling
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__
void
build_tree
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
int
bid
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
if
(
tid
>=
draft_token_num
)
{
return
;
}
int
seq_tree_idx
=
draft_token_num
*
draft_token_num
*
bid
;
for
(
int
i
=
0
;
i
<
bid
;
i
++
)
{
seq_tree_idx
+=
verified_seq_len
[
i
]
*
draft_token_num
;
}
int
seq_len
=
verified_seq_len
[
bid
];
int
token_tree_idx
=
seq_tree_idx
+
(
seq_len
+
draft_token_num
)
*
tid
+
seq_len
+
1
;
for
(
int
i
=
0
;
i
<
draft_token_num
-
1
;
i
++
)
{
tree_mask
[
token_tree_idx
+
i
]
=
false
;
}
int
position
=
0
;
if
(
tid
==
0
)
{
positions
[
bid
*
draft_token_num
]
=
seq_len
;
retrive_index
[
bid
*
draft_token_num
*
(
depth
+
2
)]
=
bid
*
draft_token_num
;
return
;
}
int
depends_order
[
10
];
int
cur_position
=
tid
-
1
;
while
(
true
)
{
depends_order
[
position
]
=
cur_position
+
1
;
position
+=
1
;
tree_mask
[
token_tree_idx
+
cur_position
]
=
true
;
int
parent_tb_idx
=
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
cur_position
]
/
topk
;
if
(
parent_tb_idx
==
0
)
{
break
;
}
int
token_idx
=
parent_list
[
bid
*
(
topk
*
(
depth
-
1
)
+
1
)
+
parent_tb_idx
];
for
(
cur_position
=
0
;
cur_position
<
draft_token_num
;
cur_position
++
)
{
if
(
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
cur_position
]
==
token_idx
)
{
break
;
}
}
if
(
cur_position
==
draft_token_num
)
{
printf
(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped."
);
break
;
}
}
positions
[
bid
*
draft_token_num
+
tid
]
=
position
+
seq_len
;
int
is_leaf
=
0
;
for
(
int
i
=
1
;
i
<
draft_token_num
;
i
++
)
{
if
(
tree_mask
[
seq_tree_idx
+
i
*
(
draft_token_num
+
seq_len
)
+
seq_len
+
tid
])
{
is_leaf
++
;
}
}
if
(
is_leaf
==
1
)
{
for
(
int
i
=
0
;
i
<
position
;
i
++
)
{
retrive_index
[(
bid
*
(
draft_token_num
)
+
tid
)
*
(
depth
+
2
)
+
position
-
i
]
=
depends_order
[
i
]
+
bid
*
draft_token_num
;
}
retrive_index
[(
bid
*
(
draft_token_num
)
+
tid
)
*
(
depth
+
2
)]
=
bid
*
draft_token_num
;
}
}
void
build_tree_kernel
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
// TODO (ying) check shape
// TODO (ying) check type
int
bs
=
parent_list
.
size
(
0
);
dim3
grid
(
bs
);
dim3
block
(
draft_token_num
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
build_tree
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
}
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cu
0 → 100644
View file @
f9905d59
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <speculative_sampling.cuh>
#include "pytorch_extension_utils.h"
using
namespace
flashinfer
;
// predicts: [tot_num_draft_tokens]
// accept_index: [bs, num_spec_step]
// accept_token_num: [bs]
// candidates: [bs, num_draft_tokens]
// retrive_index: [bs, num_draft_tokens]
// retrive_next_token: [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
// Verify a tree of draft tokens against target-model probabilities and sample.
//
// Outputs (written in place):
//   predicts:          [tot_num_draft_tokens] int32 — accepted/resampled token
//                      per draft slot (unused slots left untouched).
//   accept_index:      [bs, num_spec_step]    int32 — retrieve indices of the
//                      accepted chain.
//   accept_token_num:  [bs]                   int32 — accepted tokens per seq.
// Inputs (read-only):
//   candidates, retrive_index, retrive_next_token, retrive_next_sibling:
//                      [bs, num_draft_tokens] int32 tree description.
//   uniform_samples:   [bs, num_draft_tokens] float32 coins in [0, 1).
//   target_probs, draft_probs: [bs, num_draft_tokens, vocab_size] float32.
// `cuda_stream` is a raw stream handle passed from Python (0 = default).
//
// Throws std::runtime_error on dtype mismatch; CHECK_* macros abort on
// shape/device/contiguity violations.
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, at::Tensor accept_token_num,
                                           // mutable
                                           at::Tensor candidates, at::Tensor retrive_index,
                                           at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
                                           at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
                                           bool deterministic, int64_t cuda_stream = 0) {
  CHECK_INPUT(candidates);
  CHECK_INPUT(retrive_index);
  CHECK_INPUT(retrive_next_token);
  CHECK_INPUT(retrive_next_sibling);
  CHECK_INPUT(uniform_samples);
  CHECK_INPUT(target_probs);
  // BUGFIX: draft_probs was previously never checked even though the kernel
  // indexes it with the same offsets as target_probs.
  CHECK_INPUT(draft_probs);
  auto device = target_probs.device();
  CHECK_EQ(candidates.device(), device);
  CHECK_EQ(retrive_index.device(), device);
  CHECK_EQ(retrive_next_token.device(), device);
  CHECK_EQ(retrive_next_sibling.device(), device);
  CHECK_EQ(uniform_samples.device(), device);
  CHECK_EQ(target_probs.device(), device);
  CHECK_EQ(draft_probs.device(), device);
  CHECK_DIM(1, predicts);
  CHECK_DIM(2, accept_index);
  CHECK_DIM(1, accept_token_num);
  CHECK_DIM(2, candidates);
  CHECK_DIM(2, retrive_index);
  CHECK_DIM(2, retrive_next_token);
  CHECK_DIM(2, retrive_next_sibling);
  CHECK_DIM(2, uniform_samples);
  CHECK_DIM(3, target_probs);
  CHECK_DIM(3, draft_probs);
  unsigned int batch_size = uniform_samples.size(0);
  unsigned int num_spec_step = accept_index.size(1);
  unsigned int num_draft_tokens = candidates.size(1);
  unsigned int vocab_size = target_probs.size(2);
  CHECK_EQ(batch_size, candidates.size(0));
  CHECK_EQ(batch_size, retrive_index.size(0));
  CHECK_EQ(batch_size, retrive_next_token.size(0));
  CHECK_EQ(batch_size, retrive_next_sibling.size(0));
  CHECK_EQ(batch_size, target_probs.size(0));
  CHECK_EQ(num_draft_tokens, retrive_index.size(1));
  CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
  CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
  CHECK_EQ(num_draft_tokens, uniform_samples.size(1));
  CHECK_EQ(num_draft_tokens, target_probs.size(1));
  CHECK_EQ(vocab_size, target_probs.size(2));
  // draft_probs must be indexable with the same (bs, token, vocab) offsets as
  // target_probs (the kernel reads/writes both with one shared offset).
  CHECK_EQ(batch_size, draft_probs.size(0));
  CHECK_EQ(num_draft_tokens, draft_probs.size(1));
  CHECK_EQ(vocab_size, draft_probs.size(2));
  CHECK_EQ(batch_size, accept_index.size(0));
  CHECK_EQ(batch_size, accept_token_num.size(0));
  if (predicts.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
  }
  if (accept_index.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
  }
  if (accept_token_num.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
  }
  if (candidates.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
  }
  if (retrive_index.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
  }
  if (retrive_next_token.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
  }
  if (retrive_next_sibling.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
  }
  if (uniform_samples.scalar_type() != at::kFloat) {
    throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
  }
  if (target_probs.scalar_type() != at::kFloat) {
    throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
  }
  if (draft_probs.scalar_type() != at::kFloat) {
    // BUGFIX: this message previously said 'target_probs' (copy-paste error),
    // pointing users at the wrong tensor.
    throw std::runtime_error("Expected 'draft_probs' to be of type float (torch.float32).");
  }

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
      static_cast<int*>(predicts.data_ptr()), static_cast<int*>(accept_index.data_ptr()),
      static_cast<int*>(accept_token_num.data_ptr()), static_cast<int*>(candidates.data_ptr()),
      static_cast<int*>(retrive_index.data_ptr()), static_cast<int*>(retrive_next_token.data_ptr()),
      static_cast<int*>(retrive_next_sibling.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<float*>(target_probs.data_ptr()), static_cast<float*>(draft_probs.data_ptr()), batch_size,
      num_spec_step, num_draft_tokens, vocab_size, deterministic, stream);

  TORCH_CHECK(status == cudaSuccess,
              "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}
sgl-kernel/src/sgl-kernel/csrc/speculative_sampling.cuh
0 → 100644
View file @
f9905d59
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2024-2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SPECULATIVE_SAMPLING_CUH_
#define SPECULATIVE_SAMPLING_CUH_
#include <assert.h>
#include <flashinfer/sampling.cuh>
namespace
flashinfer
{
namespace
sampling
{
using
namespace
cub
;
// One CUDA block per batch element. Phase 1 (effectively serial, identical on
// every thread of the block): walk the draft tree accepting tokens with
// probability proportional to target_probs. Phase 2 (block-cooperative): for
// the first rejected position, sample a replacement token from
// relu(target_probs - draft_probs) using flashinfer's DeviceSamplingFromProb.
//
// predicts / accept_index / accept_token_num are outputs written in place;
// the tree layout arrays (retrive_*) and probability tensors are read-only,
// except draft_probs which is updated for rejected siblings (see FIXME below).
// `d` is the vocabulary size.
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM, BlockReduceAlgorithm REDUCE_ALGORITHM,
          uint32_t VEC_SIZE, bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index, IdType* accept_token_num,
                                                  // mutable
                                                  IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
                                                  IdType* retrive_next_sibling, DType* uniform_samples,
                                                  DType* target_probs, DType* draft_probs, uint32_t batch_size,
                                                  uint32_t num_speculative_tokens, uint32_t num_draft_tokens,
                                                  uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

  // Dynamic shared memory reused by the flashinfer block-sampling primitives.
  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage =
      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

  // prob_acc accumulates sibling probabilities so that one coin flip selects
  // among a node's children in proportion to their target probabilities.
  DType prob_acc = 0.0;
  // Offset of the probability row for the node we are currently extending;
  // starts at the root (first draft slot of this sequence).
  uint32_t cur_prob_offset = bx * num_draft_tokens * d;
  DType coin = uniform_samples[bx * num_draft_tokens];
  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
  // The root token is always "accepted" as the starting point of the chain.
  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
  uint32_t num_accepted_tokens = 0;
  IdType cur_index = 0;

  // Phase 1: descend the tree, at each depth scanning the sibling list of the
  // last accepted node until the coin falls inside some child's probability.
  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
    while (cur_index != -1) {
      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
      prob_acc += target_probs[cur_prob_offset + draft_token_id];
      if (coin < prob_acc) {
        // accept token
        prob_acc = 0.;
        cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
        coin = uniform_samples[bx * num_draft_tokens + cur_index];
        predicts[last_accepted_retrive_idx] = draft_token_id;
        ++num_accepted_tokens;
        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
        last_accepted_retrive_idx = draft_index;
        break;
      } else {
        // FIXME: leverage draft probs
        // Rejected sibling: record its target prob into draft_probs so that
        // phase 2 samples from relu(q - p), excluding rejected candidates.
        draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id];
        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
      }
    }
    // No child accepted at this depth — the chain ends here.
    if (cur_index == -1) break;
  }
  accept_token_num[bx] = num_accepted_tokens;

  // sample from relu(target_probs - draft_probs)
  // Phase 2a: block-reduce the normalizer sum(relu(q - p)) over the vocab.
  DType sum_relu_q_minus_p(0);
  vec_t<DType, VEC_SIZE> q_vec, p_vec;
  DType relu_q_minus_p[VEC_SIZE];
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (num_accepted_tokens != num_speculative_tokens - 1) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    sum_relu_q_minus_p += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                              .Sum<VEC_SIZE>(relu_q_minus_p);
    __syncthreads();
  }
  if (tx == 0) {
    temp_storage.block_aggregate.value = sum_relu_q_minus_p;
  }
  // init the first rejected token to (d - 1)
  temp_storage.sampled_id = d - 1;
  __syncthreads();
  sum_relu_q_minus_p = temp_storage.block_aggregate.value;
  // Rescale the leftover coin into [0, sum) for inverse-CDF sampling.
  DType u = coin * sum_relu_q_minus_p;

  // Phase 2b: second pass over the vocab, sampling the replacement token via
  // flashinfer's cooperative inverse-CDF primitive; early-exit once found.
  DType aggregate_relu_q_minus_p(0);
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (num_accepted_tokens != num_speculative_tokens - 1) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
    vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
        i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
    if (aggregate_relu_q_minus_p > u) {
      break;
    }
  }
  __syncthreads();
  // set the first rejected token
  predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
  // value at not used indices are undefined
}
// Host-side launcher for the TreeSpeculativeSamplingTargetOnly kernel.
// Dispatches on vector width / determinism via flashinfer macros, one block
// per batch element with BLOCK_THREADS threads. Returns the first CUDA error
// encountered (via FLASHINFER_CUDA_CALL) or cudaSuccess.
// Parameter semantics match the __global__ kernel above; `d` is vocab size.
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids,
                                              IdType* output_accepted_token_num,
                                              // mutable
                                              IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
                                              IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs,
                                              DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens,
                                              uint32_t num_draft_tokens, uint32_t d, bool deterministic,
                                              cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest 16-byte-aligned vector load that still divides the vocab size.
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  // Argument block for cudaLaunchKernel; order must match the kernel signature.
  void* args[] = {&predicts,
                  &output_token_ids,
                  &output_accepted_token_num,
                  &candidates,
                  &retrive_index,
                  &retrive_next_token,
                  &retrive_next_sibling,
                  &uniform_samples,
                  &target_probs,
                  &draft_probs,
                  &batch_size,
                  &num_speculative_tokens,
                  &num_draft_tokens,
                  &d};

  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = TreeSpeculativeSamplingTargetOnly<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, DETERMINISTIC,
                                                        DType, IdType>;
        // Opt in to the required dynamic shared memory before launching.
        FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
      })});
  return cudaSuccess;
}
}
// namespace sampling
}
// namespace flashinfer
#endif // SPECULATIVE_SAMPLING_CUH_
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
View file @
f9905d59
...
...
@@ -127,3 +127,19 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
void
apply_rope_pos_ids_cos_sin_cache
(
at
::
Tensor
q
,
at
::
Tensor
k
,
at
::
Tensor
q_rope
,
at
::
Tensor
k_rope
,
at
::
Tensor
cos_sin_cache
,
at
::
Tensor
pos_ids
,
bool
interleave
,
int64_t
cuda_stream
);
// Tree speculative decoding: verify draft tokens against target probabilities
// and sample a replacement for the first rejection. predicts / accept_index /
// accept_token_num are output buffers written in place
// (implemented in csrc/speculative_sampling.cu).
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, at::Tensor accept_token_num,
                                           // mutable
                                           at::Tensor candidates, at::Tensor retrive_index,
                                           at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
                                           at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
                                           bool deterministic = true, int64_t cuda_stream = 0);

// EAGLE draft-tree construction that additionally emits the next-token /
// next-sibling linkage used by tree_speculative_sampling_target_only
// (implemented in csrc/eagle_utils.cu).
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                                 at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
                                 at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
                                 int64_t depth, int64_t draft_token_num);

// Original EAGLE draft-tree construction: fills tree_mask, positions and
// retrive_index in place (implemented in csrc/eagle_utils.cu).
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
                       int64_t depth, int64_t draft_token_num);
sgl-kernel/src/sgl-kernel/ops/__init__.py
View file @
f9905d59
...
...
@@ -495,3 +495,87 @@ def min_p_sampling_from_probs(
return
_min_p_sampling_from_probs_internal
(
probs
,
uniform_samples
,
*
_to_tensor_scalar_tuple
(
min_p
),
deterministic
)
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Verify a draft-token tree and sample accepted tokens, in place.

    Thin wrapper over the ``sgl_kernels.tree_speculative_sampling_target_only``
    CUDA op. ``predicts``, ``accept_index`` and ``accept_token_num`` are output
    buffers mutated by the kernel; the remaining tensors are inputs. Shape and
    dtype constraints (int32 indices, float32 probabilities) are enforced by
    the C++ side, which raises on mismatch.
    """
    # NOTE(review): entering `predicts.device` as a context manager sets the
    # current device for the op call — confirm minimum torch version supports
    # torch.device as a context manager.
    with predicts.device as device:
        torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
            predicts,
            accept_index,
            accept_token_num,
            candidates,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            uniform_samples,
            target_probs,
            draft_probs,
            deterministic,
            _get_cuda_stream(device),
        )
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the EAGLE draft tree (efficient variant), writing outputs in place.

    Wrapper over the ``sgl_kernels.build_tree_kernel_efficient`` CUDA op.
    ``tree_mask``, ``positions``, ``retrive_index``, ``retrive_next_token`` and
    ``retrive_next_sibling`` are output buffers mutated by the kernel; the
    next-token / next-sibling linkage is what
    ``tree_speculative_sampling_target_only`` later walks.
    """
    # Run the op with parent_list's device as the current device.
    with parent_list.device as device:
        torch.ops.sgl_kernels.build_tree_kernel_efficient(
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            topk,
            depth,
            draft_token_num,
        )
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the EAGLE draft tree (original variant), writing outputs in place.

    Wrapper over the ``sgl_kernels.build_tree_kernel`` CUDA op. ``tree_mask``,
    ``positions`` and ``retrive_index`` are output buffers mutated by the
    kernel; unlike :func:`build_tree_kernel_efficient`, no next-token /
    next-sibling linkage is produced.
    """
    # Run the op with parent_list's device as the current device.
    with parent_list.device as device:
        torch.ops.sgl_kernels.build_tree_kernel(
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            topk,
            depth,
            draft_token_num,
        )
sgl-kernel/src/sgl-kernel/torch_extension.cc
View file @
f9905d59
...
...
@@ -130,6 +130,29 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()"
);
m
.
impl
(
"apply_rope_pos_ids_cos_sin_cache"
,
torch
::
kCUDA
,
&
apply_rope_pos_ids_cos_sin_cache
);
// tree spec decode
m
.
def
(
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
"bool deterministic, int cuda_stream) -> ()"
);
m
.
impl
(
"tree_speculative_sampling_target_only"
,
torch
::
kCUDA
,
&
tree_speculative_sampling_target_only
);
// eagle build tree
m
.
def
(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
"retrive_next_sibling, "
"int topk, int depth, int draft_token_num) -> ()"
);
m
.
impl
(
"build_tree_kernel_efficient"
,
torch
::
kCUDA
,
&
build_tree_kernel_efficient
);
// eagle build tree
m
.
def
(
"build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
"int topk, int depth, int draft_token_num) -> ()"
);
m
.
impl
(
"build_tree_kernel"
,
torch
::
kCUDA
,
&
build_tree_kernel
);
}
REGISTER_EXTENSION
(
_kernels
)
sgl-kernel/tests/test_speculative_sampling.py
0 → 100644
View file @
f9905d59
import
torch
import
torch.nn.functional
as
F
from
sgl_kernel
import
tree_speculative_sampling_target_only
def test_tree_speculative_sampling_target_only():
    """End-to-end check of the tree speculative-sampling CUDA kernel.

    Builds two small draft trees (6 draft tokens each, vocab size 20), gives
    the target model near-one-hot probabilities via a very low temperature,
    and asserts the exact accepted chains. Requires a CUDA device.
    """
    # Draft token ids per sequence; index 0 is the tree root.
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    # Flat position of each tree node in `predicts` (seq 1 is offset by 6).
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    # First-child links; -1 marks a leaf.
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    # Next-sibling links; -1 marks the last sibling.
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )

    # Spike the target logits so each position deterministically prefers one
    # token; every position without an explicit spike falls back to token 18.
    target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda")
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    for i in range(target_logits.shape[0]):
        for j in range(target_logits.shape[1]):
            if torch.max(target_logits[i][j]) < 10:
                target_logits[i][j][18] = 10

    # Near-zero temperature makes softmax effectively one-hot, so the expected
    # accept chains below are deterministic despite the random coins.
    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device="cuda")
    predict_shape = (12,)

    bs = candidates.shape[0]
    num_spec_step = 4
    num_draft_tokens = candidates.shape[1]

    # Output buffers mutated by the kernel; -1 / 0 mark "unwritten".
    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")  # mutable
    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device="cuda")  # mutable
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable

    expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
    target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
    # Zero draft probs: the rejection-resampling distribution relu(q - p)
    # reduces to the target distribution itself.
    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device="cuda")

    coins = torch.rand(bs, num_draft_tokens, device="cuda").to(torch.float32)
    print(f"{candidates=}")
    print(f"{retrive_index=}")
    print(f"{retrive_next_token=}")
    print(f"{retrive_next_sibling=}")
    print(f"{coins=}")

    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=coins,
        target_probs=target_probs,
        draft_probs=draft_probs,
        deterministic=True,
    )

    print(f"{predicts=}")
    print(f"{accept_index=}")
    print(f"{accept_token_num=}")

    # Seq 0 accepts the chain 0 -> 3 -> 4 -> 5 (3 tokens + bonus sample 18);
    # seq 1 accepts 6 -> 10 -> 11 (2 tokens + bonus sample 18).
    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
    assert accept_index.tolist() == [
        [0, 3, 4, 5],
        [6, 10, 11, -1],
    ]
    assert accept_token_num.tolist() == [3, 2]
if
__name__
==
"__main__"
:
test_tree_speculative_sampling_target_only
()
sgl-kernel/version.py
View file @
f9905d59
__version__
=
"0.0.3.post
1
"
__version__
=
"0.0.3.post
2
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment