Unverified commit 52a34d74, authored Mar 16, 2025 by Ying Sheng, committed by GitHub on Mar 16, 2025.
Parent: 06d12b39

Add greedy verification kernel (#4383)

Showing 11 changed files with 392 additions and 151 deletions (+392 −151):
sgl-kernel/csrc/speculative/eagle_utils.cu                    +129  −98
sgl-kernel/csrc/speculative/packbit.cu                         +47   −0
sgl-kernel/csrc/speculative/speculative_sampling.cu             +9   −2
sgl-kernel/csrc/speculative/speculative_sampling.cuh           +13   −5
sgl-kernel/csrc/torch_extension.cc                             +11   −8
sgl-kernel/include/sgl_kernel_ops.h                            +17  −12
sgl-kernel/python/sgl_kernel/__init__.py                        +6   −1
sgl-kernel/python/sgl_kernel/speculative.py                    +38  −20
sgl-kernel/setup.py                                             +1   −0
sgl-kernel/tests/speculative/test_eagle_utils.py               +98   −0
sgl-kernel/tests/speculative/test_speculative_sampling.py      +23   −5
sgl-kernel/csrc/speculative/eagle_utils.cu

@@ -17,6 +17,8 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include "pytorch_extension_utils.h"

 // parent_list [bs, topk * (depth - 1) + 1)]
 // selected_index [bs, draft_token_num - 1]
 // verified_seq_len [bs]
@@ -72,8 +74,8 @@ __global__ void build_tree_efficient(
     }
     if (parent_position == draft_token_num) {
       printf(
-          "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
-          "will be dropped.");
+          "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
+          "Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
       continue;
     }
@@ -140,112 +142,141 @@ void build_tree_kernel_efficient(
       int32_t(draft_token_num));
 }

Removed lines (the old build_tree kernel and its build_tree_kernel launcher):

-// parent_list [bs, topk * (depth - 1) + 1)]
-// selected_index [bs, draft_token_num - 1]
-// verified_seq_len [bs]
-// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
-// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
-// draft_token, depth + 2]
-__global__ void build_tree(
-    int64_t* parent_list,
-    int64_t* selected_index,
-    int32_t* verified_seq_len,
-    bool* tree_mask,
-    int64_t* positions,
-    int64_t* retrive_index,
-    int topk,
-    int depth,
-    int draft_token_num) {
-  int bid = blockIdx.x;
-  int tid = threadIdx.x;
-  if (tid >= draft_token_num) {
-    return;
-  }
-
-  int seq_tree_idx = draft_token_num * draft_token_num * bid;
-  for (int i = 0; i < bid; i++) {
-    seq_tree_idx += verified_seq_len[i] * draft_token_num;
-  }
-  int seq_len = verified_seq_len[bid];
-  int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
-  for (int i = 0; i < draft_token_num - 1; i++) {
-    tree_mask[token_tree_idx + i] = false;
-  }
-
-  int position = 0;
-  if (tid == 0) {
-    positions[bid * draft_token_num] = seq_len;
-    retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num;
-    return;
-  }
-
-  int depends_order[10];
-
-  int cur_position = tid - 1;
-  while (true) {
-    depends_order[position] = cur_position + 1;
-    position += 1;
-    tree_mask[token_tree_idx + cur_position] = true;
-    int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
-    if (parent_tb_idx == 0) {
-      break;
-    }
-
-    int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
-    for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
-      if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
-        break;
-      }
-    }
-    if (cur_position == draft_token_num) {
-      printf(
-          "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
-          "will be dropped.");
-      break;
-    }
-  }
-  positions[bid * draft_token_num + tid] = position + seq_len;
-
-  int is_leaf = 0;
-  for (int i = 1; i < draft_token_num; i++) {
-    if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
-      is_leaf++;
-    }
-  }
-  if (is_leaf == 1) {
-    for (int i = 0; i < position; i++) {
-      retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] =
-          depends_order[i] + bid * draft_token_num;
-    }
-    retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num;
-  }
-}
-
-void build_tree_kernel(
-    at::Tensor parent_list,
-    at::Tensor selected_index,
-    at::Tensor verified_seq_len,
-    at::Tensor tree_mask,
-    at::Tensor positions,
-    at::Tensor retrive_index,
-    int64_t topk,
-    int64_t depth,
-    int64_t draft_token_num) {
-  // TODO (ying) check shape
-  // TODO (ying) check type
-  int bs = parent_list.size(0);
-  dim3 grid(bs);
-  dim3 block(draft_token_num);
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  build_tree<<<grid, block, 0, stream>>>(
-      static_cast<int64_t*>(parent_list.data_ptr()),
-      static_cast<int64_t*>(selected_index.data_ptr()),
-      static_cast<int32_t*>(verified_seq_len.data_ptr()),
-      static_cast<bool*>(tree_mask.data_ptr()),
-      static_cast<int64_t*>(positions.data_ptr()),
-      static_cast<int64_t*>(retrive_index.data_ptr()),
-      int32_t(topk),
-      int32_t(depth),
-      int32_t(draft_token_num));
-}

Added lines (the new greedy verification kernel and its verify_tree_greedy launcher):

+template <typename IdType>
+__global__ void VerifyTreeGreedy(
+    IdType* predicts,
+    IdType* accept_index,
+    IdType* accept_token_num,  // mutable
+    IdType* candidates,
+    IdType* retrive_index,
+    IdType* retrive_next_token,
+    IdType* retrive_next_sibling,
+    IdType* target_predict,
+    uint32_t batch_size,
+    uint32_t num_speculative_tokens,
+    uint32_t num_draft_tokens) {
+  uint32_t bx = blockIdx.x;
+
+  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
+  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
+  uint32_t num_accepted_tokens = 0;
+  IdType cur_index = 0;
+
+  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
+    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
+    while (cur_index != -1) {
+      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
+      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
+      IdType target_token_id = target_predict[last_accepted_retrive_idx];
+
+      if (draft_token_id == target_token_id) {
+        // accept token
+        predicts[last_accepted_retrive_idx] = target_token_id;
+        ++num_accepted_tokens;
+        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
+        last_accepted_retrive_idx = draft_index;
+        break;
+      } else {
+        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
+      }
+    }
+    if (cur_index == -1) break;
+  }
+  accept_token_num[bx] = num_accepted_tokens;
+  predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx];
+}
+
+// predicts: [tot_num_draft_tokens]
+// accept_index: [bs, num_spec_step]
+// accept_token_num: [bs]
+// candidates: [bs, num_draft_tokens]
+// retrive_index: [bs, num_draft_tokens]
+// retrive_next_token: [bs, num_draft_tokens]
+// retrive_next_sibling: [bs, num_draft_tokens]
+// target_predict: [bs, num_draft_tokens]
+void verify_tree_greedy(
+    at::Tensor predicts,
+    at::Tensor accept_index,
+    at::Tensor accept_token_num,  // mutable
+    at::Tensor candidates,
+    at::Tensor retrive_index,
+    at::Tensor retrive_next_token,
+    at::Tensor retrive_next_sibling,
+    at::Tensor target_predict,
+    int64_t cuda_stream = 0) {
+  CHECK_INPUT(candidates);
+  CHECK_INPUT(retrive_index);
+  CHECK_INPUT(retrive_next_token);
+  CHECK_INPUT(retrive_next_sibling);
+  CHECK_INPUT(target_predict);
+  auto device = target_predict.device();
+  CHECK_EQ(candidates.device(), device);
+  CHECK_EQ(retrive_index.device(), device);
+  CHECK_EQ(retrive_next_token.device(), device);
+  CHECK_EQ(retrive_next_sibling.device(), device);
+  CHECK_EQ(target_predict.device(), device);
+  CHECK_DIM(1, predicts);
+  CHECK_DIM(2, accept_index);
+  CHECK_DIM(1, accept_token_num);
+  CHECK_DIM(2, candidates);
+  CHECK_DIM(2, retrive_index);
+  CHECK_DIM(2, retrive_next_token);
+  CHECK_DIM(2, retrive_next_sibling);
+  CHECK_DIM(2, target_predict);
+  unsigned int batch_size = candidates.size(0);
+  unsigned int num_spec_step = accept_index.size(1);
+  unsigned int num_draft_tokens = candidates.size(1);
+  CHECK_EQ(batch_size, accept_index.size(0));
+  CHECK_EQ(batch_size, accept_token_num.size(0));
+  CHECK_EQ(batch_size, retrive_index.size(0));
+  CHECK_EQ(batch_size, retrive_next_token.size(0));
+  CHECK_EQ(batch_size, retrive_next_sibling.size(0));
+  CHECK_EQ(batch_size, target_predict.size(0));
+  CHECK_EQ(num_draft_tokens, retrive_index.size(1));
+  CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
+  CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
+  CHECK_EQ(num_draft_tokens, target_predict.size(1));
+  CHECK_EQ(batch_size, accept_index.size(0));
+  CHECK_EQ(batch_size, accept_token_num.size(0));
+  if (predicts.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
+  }
+  if (accept_index.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
+  }
+  if (accept_token_num.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
+  }
+  if (candidates.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
+  }
+  if (retrive_index.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
+  }
+  if (retrive_next_token.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
+  }
+  if (retrive_next_sibling.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
+  }
+  if (target_predict.scalar_type() != at::kInt) {
+    throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32).");
+  }
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  dim3 grid(batch_size);
+  dim3 block(1);
+  VerifyTreeGreedy<int><<<grid, block, 0, stream>>>(
+      static_cast<int*>(predicts.data_ptr()),
+      static_cast<int*>(accept_index.data_ptr()),
+      static_cast<int*>(accept_token_num.data_ptr()),
+      static_cast<int*>(candidates.data_ptr()),
+      static_cast<int*>(retrive_index.data_ptr()),
+      static_cast<int*>(retrive_next_token.data_ptr()),
+      static_cast<int*>(retrive_next_sibling.data_ptr()),
+      static_cast<int*>(target_predict.data_ptr()),
+      batch_size,
+      num_spec_step,
+      num_draft_tokens);
+}
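For readers skimming the diff, here is a minimal CPU reference sketch (Python, plain torch) of the walk that VerifyTreeGreedy performs per batch row: follow retrive_next_token to the first child of the last accepted node, walk retrive_next_sibling across alternative drafts, and accept a draft token only when it equals the target model's argmax at the last accepted position. The function name verify_tree_greedy_reference and the in-place update style are illustrative only; the actual implementation is the CUDA kernel above.

import torch

def verify_tree_greedy_reference(predicts, accept_index, accept_token_num,
                                 candidates, retrive_index, retrive_next_token,
                                 retrive_next_sibling, target_predict):
    # Shapes follow the comments above verify_tree_greedy():
    #   predicts [tot_num_draft_tokens], accept_index [bs, num_spec_step],
    #   accept_token_num [bs], everything else [bs, num_draft_tokens].
    bs, num_draft_tokens = candidates.shape
    num_spec_step = accept_index.shape[1]
    target_flat = target_predict.reshape(-1)
    for bx in range(bs):
        last_accepted = int(retrive_index[bx, 0])   # the root is always accepted
        accept_index[bx, 0] = last_accepted
        num_accepted = 0
        cur = 0                                     # position within this row
        for _ in range(1, num_spec_step):
            cur = int(retrive_next_token[bx, cur])  # first child of the accepted node
            while cur != -1:
                draft_token = int(candidates[bx, cur])
                target_token = int(target_flat[last_accepted])
                if draft_token == target_token:
                    # greedy accept: the draft matches the target model's argmax
                    predicts[last_accepted] = target_token
                    num_accepted += 1
                    accept_index[bx, num_accepted] = int(retrive_index[bx, cur])
                    last_accepted = int(retrive_index[bx, cur])
                    break
                cur = int(retrive_next_sibling[bx, cur])  # otherwise try the next sibling
            if cur == -1:
                break
        accept_token_num[bx] = num_accepted
        # bonus token: the target model's own prediction at the last accepted slot
        predicts[last_accepted] = target_flat[last_accepted]

Running this sketch on the CPU copies of the tensors in test_eagle_utils.py below should reproduce the expected predicts, accept_index, and accept_token_num values asserted there.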
sgl-kernel/csrc/speculative/packbit.cu (new file, 0 → 100644)
// This is only a pluggin used for flashinfer 0.1.6. The new version does not need it.
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/quantization.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;

// bitorder = "little"
void segment_packbits(
    at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) {
  CHECK_INPUT(x);
  CHECK_INPUT(input_indptr);
  CHECK_INPUT(output_indptr);
  auto device = x.device();
  CHECK_EQ(input_indptr.device(), device);
  CHECK_EQ(output_indptr.device(), device);
  CHECK_EQ(y.device(), device);
  unsigned int batch_size = input_indptr.size(0) - 1;
  CHECK_EQ(output_indptr.size(0), batch_size + 1);
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status = quantization::SegmentPackBits(
      static_cast<bool*>(x.data_ptr()),
      static_cast<uint8_t*>(y.data_ptr()),
      static_cast<int32_t*>(input_indptr.data_ptr()),
      static_cast<int32_t*>(output_indptr.data_ptr()),
      batch_size,
      quantization::BitOrder::kLittle,
      stream);
}
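The wrapper above only forwards to flashinfer's SegmentPackBits. As a rough mental model, here is a CPU sketch of the intended semantics in NumPy; the function name and the assumption that output_indptr reserves ceil(segment_len / 8) bytes per segment are mine, not part of the kernel.

import numpy as np

def segment_packbits_reference(x, input_indptr, output_indptr):
    # Pack each boolean segment x[input_indptr[i]:input_indptr[i+1]] into bytes,
    # little-endian bit order, writing segment i starting at output_indptr[i].
    y = np.zeros(int(output_indptr[-1]), dtype=np.uint8)
    for i in range(len(input_indptr) - 1):
        seg = x[input_indptr[i]:input_indptr[i + 1]]
        packed = np.packbits(seg, bitorder="little")
        y[output_indptr[i]:output_indptr[i] + len(packed)] = packed
    return y

# Two segments: [1, 0, 1] -> 0b101 = 5 and [1, 1] -> 0b11 = 3.
print(segment_packbits_reference(
    np.array([1, 0, 1, 1, 1], dtype=bool),
    np.array([0, 3, 5]),
    np.array([0, 1, 2])))  # -> [5 3]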
sgl-kernel/csrc/speculative/speculative_sampling.cu
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "pytorch_extension_utils.h"
 #include "speculative_sampling.cuh"

@@ -40,7 +39,9 @@ void tree_speculative_sampling_target_only(
     at::Tensor uniform_samples,
     at::Tensor target_probs,
     at::Tensor draft_probs,
-    bool deterministic,
+    double threshold_single,
+    double threshold_acc,
+    bool deterministic = true,
     int64_t cuda_stream = 0) {
   CHECK_INPUT(candidates);
   CHECK_INPUT(retrive_index);

@@ -112,6 +113,10 @@ void tree_speculative_sampling_target_only(
   if (draft_probs.scalar_type() != at::kFloat) {
     throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
   }
+  CHECK_GE(threshold_single, 0);
+  CHECK_GE(1, threshold_single);
+  CHECK_GE(threshold_acc, 0);
+  CHECK_GE(1, threshold_acc);
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(

@@ -129,6 +134,8 @@ void tree_speculative_sampling_target_only(
       num_spec_step,
       num_draft_tokens,
       vocab_size,
+      static_cast<float>(threshold_single),
+      static_cast<float>(threshold_acc),
       deterministic,
       stream);
sgl-kernel/csrc/speculative/speculative_sampling.cuh

@@ -49,7 +49,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
     uint32_t batch_size,
     uint32_t num_speculative_tokens,
     uint32_t num_draft_tokens,
-    uint32_t d) {
+    uint32_t d,
+    DType threshold_single,
+    DType threshold_acc) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))

@@ -70,9 +72,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
     while (cur_index != -1) {
       IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
       IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
-      prob_acc += target_probs[cur_prob_offset + draft_token_id];
+      DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
+      prob_acc += target_prob_single;

-      if (coin < prob_acc) {
+      if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
         // accept token
         prob_acc = 0.;
         cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;

@@ -169,7 +172,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
     uint32_t num_speculative_tokens,
     uint32_t num_draft_tokens,
     uint32_t d,
-    bool deterministic,
+    DType threshold_single = 1,
+    DType threshold_acc = 1,
+    bool deterministic = true,
     cudaStream_t stream = 0) {
   constexpr uint32_t BLOCK_THREADS = 1024;
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

@@ -177,6 +182,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
   const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
+  float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
   void* args[] = {
       &predicts,
       &output_token_ids,

@@ -191,7 +197,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
       &batch_size,
       &num_speculative_tokens,
       &num_draft_tokens,
-      &d};
+      &d,
+      &threshold_single,
+      &capped_threshold_acc};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
       auto kernel = TreeSpeculativeSamplingTargetOnly<
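In plain terms, the hunks above relax the acceptance test: a candidate is now also accepted when its own target probability clears threshold_single, and the accumulated probability mass is scaled by threshold_acc (capped away from zero before launch). A small Python sketch of that condition, with illustrative names only (this is a reading of the hunk, not library code):

def accept_draft_token(coin, prob_acc, target_prob_single,
                       threshold_single=1.0, threshold_acc=1.0):
    # prob_acc already includes target_prob_single for the current candidate,
    # mirroring `prob_acc += target_prob_single` in the kernel.
    capped_threshold_acc = max(threshold_acc, 1e-9)   # fmaxf(threshold_acc, 1e-9f)
    return (coin <= prob_acc / capped_threshold_acc
            or target_prob_single >= threshold_single)

With the default threshold_single = threshold_acc = 1 this behaves essentially like the previous coin < prob_acc test; lowering either threshold makes each examined candidate easier to accept (with threshold_acc = 0 the first child is always taken), which is what the new test case with threshold_single=0, threshold_acc=0 exercises at the end of this commit.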
sgl-kernel/csrc/torch_extension.cc

@@ -129,21 +129,24 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def(
       "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
       "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
+      "float threshold_single, float threshold_acc, "
       "bool deterministic, int cuda_stream) -> ()");
   m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);

+  m.def(
+      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor target_predict, int cuda_stream) -> ()");
+  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);

   m.def(
       "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
-      "retrive_next_sibling, "
-      "int topk, int depth, int draft_token_num) -> ()");
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
+      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
   m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);

-  m.def(
-      "build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
-      "int topk, int depth, int draft_token_num) -> ()");
-  m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);

+  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
+  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);

   /*
    * From FlashInfer
sgl-kernel/include/sgl_kernel_ops.h

@@ -183,8 +183,8 @@ void topk_softmax(
  * From csrc/speculative
  */
 void tree_speculative_sampling_target_only(
-    at::Tensor predicts,
-    at::Tensor accept_index,
+    at::Tensor predicts,          // mutable
+    at::Tensor accept_index,      // mutable
     at::Tensor accept_token_num,  // mutable
     at::Tensor candidates,
     at::Tensor retrive_index,

@@ -193,33 +193,38 @@ void tree_speculative_sampling_target_only(
     at::Tensor uniform_samples,
     at::Tensor target_probs,
     at::Tensor draft_probs,
-    bool deterministic,
+    double threshold_single = 1,
+    double threshold_acc = 1,
+    bool deterministic = true,
     int64_t cuda_stream = 0);

+void verify_tree_greedy(
+    at::Tensor predicts,          // mutable
+    at::Tensor accept_index,      // mutable
+    at::Tensor accept_token_num,  // mutable
+    at::Tensor candidates,
+    at::Tensor retrive_index,
+    at::Tensor retrive_next_token,
+    at::Tensor retrive_next_sibling,
+    at::Tensor target_predict,
+    int64_t cuda_stream = 0);

-void build_tree_kernel(
-    at::Tensor parent_list,
-    at::Tensor selected_index,
-    at::Tensor verified_seq_len,
-    at::Tensor tree_mask,
-    at::Tensor positions,
-    at::Tensor retrive_index,
-    int64_t topk,
-    int64_t depth,
-    int64_t draft_token_num);

 void build_tree_kernel_efficient(
     at::Tensor parent_list,
     at::Tensor selected_index,
     at::Tensor verified_seq_len,
     at::Tensor tree_mask,
     at::Tensor positions,
     at::Tensor retrive_index,
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
     int64_t topk,
     int64_t depth,
     int64_t draft_token_num);

+void segment_packbits(
+    at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);

 /*
  * From FlashInfer
  */
sgl-kernel/python/sgl_kernel/__init__.py

@@ -42,8 +42,13 @@ from sgl_kernel.sampling import (
     top_p_sampling_from_probs,
 )
 from sgl_kernel.speculative import (
-    build_tree_kernel,
     build_tree_kernel_efficient,
     segment_packbits,
     tree_speculative_sampling_target_only,
+    verify_tree_greedy,
 )
 from sgl_kernel.version import __version__
+
+build_tree_kernel = (
+    None  # TODO(ying): remove this after updating the sglang python code.
+)
sgl-kernel/python/sgl_kernel/speculative.py

@@ -13,6 +13,8 @@ def tree_speculative_sampling_target_only(
     uniform_samples: torch.Tensor,
     target_probs: torch.Tensor,
     draft_probs: torch.Tensor,
+    threshold_single: float = 1.0,
+    threshold_acc: float = 1.0,
     deterministic: bool = True,
 ) -> None:
     torch.ops.sgl_kernel.tree_speculative_sampling_target_only(

@@ -26,58 +28,74 @@ def tree_speculative_sampling_target_only(
         uniform_samples,
         target_probs,
         draft_probs,
+        threshold_single,
+        threshold_acc,
         deterministic,
         get_cuda_stream(),
     )

Added (the new Python wrapper for the greedy verification kernel):

+def verify_tree_greedy(
+    predicts: torch.Tensor,  # mutable
+    accept_index: torch.Tensor,  # mutable
+    accept_token_num: torch.Tensor,  # mutable
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+) -> None:
+    torch.ops.sgl_kernel.verify_tree_greedy(
+        predicts,
+        accept_index,
+        accept_token_num,
+        candidates,
+        retrive_index,
+        retrive_next_token,
+        retrive_next_sibling,
+        target_predict,
+        get_cuda_stream(),
+    )

Removed (the old build_tree_kernel wrapper):

-def build_tree_kernel(
-    parent_list: torch.Tensor,
-    selected_index: torch.Tensor,
-    verified_seq_len: torch.Tensor,
-    tree_mask: torch.Tensor,
-    positions: torch.Tensor,
-    retrive_index: torch.Tensor,
-    topk: int,
-    depth: int,
-    draft_token_num: int,
-) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel(
-        parent_list,
-        selected_index,
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        draft_token_num,
-    )

Unchanged wrappers shown for context:

 def build_tree_kernel_efficient(
     parent_list: torch.Tensor,
     selected_index: torch.Tensor,
     verified_seq_len: torch.Tensor,
     tree_mask: torch.Tensor,
     positions: torch.Tensor,
     retrive_index: torch.Tensor,
     retrive_next_token: torch.Tensor,
     retrive_next_sibling: torch.Tensor,
     topk: int,
     depth: int,
     draft_token_num: int,
 ) -> None:
     torch.ops.sgl_kernel.build_tree_kernel_efficient(
         parent_list,
         selected_index,
         verified_seq_len,
         tree_mask,
         positions,
         retrive_index,
         retrive_next_token,
         retrive_next_sibling,
         topk,
         depth,
         draft_token_num,
     )


 def segment_packbits(
     x: torch.Tensor,
     input_indptr: torch.Tensor,
     output_indptr: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
     torch.ops.sgl_kernel.segment_packbits(
         x,
         input_indptr,
         output_indptr,
         y,
         torch.cuda.current_stream().cuda_stream,
     )
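Usage note: predicts is written at the flat slots named in accept_index, so downstream code can recover each request's verified continuation (including the bonus token) by gathering predicts at the non-negative entries of accept_index. The helper below is illustrative only (not part of sgl_kernel) and uses the expected outputs of the test file later in this commit:

import torch

def gather_accepted_tokens(predicts, accept_index):
    # predicts: flat [tot_num_draft_tokens]; accept_index: [bs, num_spec_step], padded with -1
    return [[int(predicts[i]) for i in row if i != -1] for row in accept_index.tolist()]

# Expected outputs of test_eagle_utils.py below:
predicts = torch.tensor([3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18], dtype=torch.int32)
accept_index = torch.tensor([[0, 3, 4, 5], [6, 10, 11, -1]], dtype=torch.int32)
print(gather_accepted_tokens(predicts, accept_index))  # [[3, 4, 5, 18], [11, 12, 18]]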
sgl-kernel/setup.py

@@ -209,6 +209,7 @@ sources = [
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/speculative/speculative_sampling.cu",
+    "csrc/speculative/packbit.cu",
     "csrc/torch_extension.cc",
     "3rdparty/flashinfer/csrc/norm.cu",
     "3rdparty/flashinfer/csrc/renorm.cu",
sgl-kernel/tests/speculative/test_eagle_utils.py (new file, 0 → 100644)

import torch
import torch.nn.functional as F

from sgl_kernel import verify_tree_greedy


def test_verify_tree_greedy():
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int32,
        device="cuda",
    )

    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    for i in range(target_logits.shape[0]):
        for j in range(target_logits.shape[1]):
            if torch.max(target_logits[i][j]) < 10:
                target_logits[i][j][18] = 10
    print(f"{target_logits=}")

    target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
    predict_shape = (12,)

    bs = candidates.shape[0]
    num_spec_step = 4
    num_draft_tokens = candidates.shape[1]

    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda")  # mutable
    accept_index = torch.full((bs, num_spec_step), -1, dtype=torch.int32, device="cuda")  # mutable
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda")  # mutable

    print(f"{candidates=}")
    print(f"{retrive_index=}")
    print(f"{retrive_next_token=}")
    print(f"{retrive_next_sibling=}")
    print(f"{target_predict=}")

    verify_tree_greedy(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        target_predict=target_predict,
    )

    print(f"{predicts=}")
    print(f"{accept_index=}")
    print(f"{accept_token_num=}")
    return predicts, accept_index, accept_token_num


if __name__ == "__main__":
    predicts, accept_index, accept_token_num = test_verify_tree_greedy()
    assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
    assert accept_index.tolist() == [
        [0, 3, 4, 5],
        [6, 10, 11, -1],
    ]
    assert accept_token_num.tolist() == [3, 2]
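The retrive_next_token / retrive_next_sibling pair is a first-child / next-sibling encoding of the draft tree. As a quick sanity aid, the illustrative helper below (not part of sgl_kernel) expands batch row 0 of the test above into an explicit children map:

def decode_tree(next_token, next_sibling):
    # first-child / next-sibling linked lists -> explicit children per node
    children = {node: [] for node in range(len(next_token))}
    for node in range(len(next_token)):
        child = next_token[node]          # first child of `node`, or -1
        while child != -1:
            children[node].append(child)
            child = next_sibling[child]   # walk the sibling chain
    return children

# Row 0 of the test: node 0 branches to 1 and 3; 1 -> 2, and 3 -> 4 -> 5.
print(decode_tree([1, 2, -1, 4, 5, -1], [-1, 3, -1, -1, -1, -1]))
# {0: [1, 3], 1: [2], 2: [], 3: [4], 4: [5], 5: []}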
sgl-kernel/tests/test_speculative_sampling.py → sgl-kernel/tests/speculative/test_speculative_sampling.py

@@ -3,7 +3,10 @@ import torch.nn.functional as F
 from sgl_kernel import tree_speculative_sampling_target_only


-def test_tree_speculative_sampling_target_only():
+def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
+    print(f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n")
     candidates = torch.tensor(
         [
             [0, 1, 2, 3, 4, 5],

@@ -37,7 +40,7 @@ def test_tree_speculative_sampling_target_only():
         device="cuda",
     )

-    target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda")
+    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
     target_logits[0, 0, 3] = 10
     target_logits[0, 3, 4] = 10
     target_logits[0, 4, 5] = 10

@@ -85,6 +88,8 @@ def test_tree_speculative_sampling_target_only():
         uniform_samples=coins,
         target_probs=target_probs,
         draft_probs=draft_probs,
+        threshold_single=threshold_single,
+        threshold_acc=threshold_acc,
         deterministic=True,
     )

@@ -92,6 +97,13 @@ def test_tree_speculative_sampling_target_only():
     print(f"{accept_index=}")
     print(f"{accept_token_num=}")
+    return predicts, accept_index, accept_token_num


+if __name__ == "__main__":
+    predicts, accept_index, accept_token_num = (
+        test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
+    )
     assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
     assert accept_index.tolist() == [
         [0, 3, 4, 5],

@@ -99,6 +111,12 @@ def test_tree_speculative_sampling_target_only():
     ]
     assert accept_token_num.tolist() == [3, 2]

-if __name__ == "__main__":
-    test_tree_speculative_sampling_target_only()
+    predicts, accept_index, accept_token_num = (
+        test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
+    )
+    assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
+    assert accept_index.tolist() == [
+        [0, 1, 2, -1],
+        [6, 10, 11, -1],
+    ]
+    assert accept_token_num.tolist() == [2, 2]