change / sglang · Commits · cfceb83d

Commit cfceb83d (unverified)
Authored Jun 16, 2025 by Lianmin Zheng; committed via GitHub on Jun 16, 2025

Fix sampling for speculative decoding & simplify kernels (#7207)

Parent: b1286a11
Showing 11 changed files with 124 additions and 79 deletions (+124, −79)
Changed files:

  sgl-kernel/csrc/common_extension.cc                         +5   −2
  sgl-kernel/csrc/speculative/eagle_utils.cu                  +32  −32
  sgl-kernel/csrc/speculative/packbit.cu                      +7   −3
  sgl-kernel/csrc/speculative/speculative_sampling.cu         +23  −16
  sgl-kernel/csrc/speculative/speculative_sampling.cuh        +23  −15
  sgl-kernel/include/sgl_kernel_ops.h                         +7   −1
  sgl-kernel/python/sgl_kernel/__init__.py                    +1   −0
  sgl-kernel/python/sgl_kernel/speculative.py                 +4   −0
  sgl-kernel/python/sgl_kernel/top_k.py                       +11  −0
  sgl-kernel/tests/speculative/test_eagle_utils.py            +5   −6
  sgl-kernel/tests/speculative/test_speculative_sampling.py   +6   −4
sgl-kernel/csrc/common_extension.cc  (mode 100755 → 100644)

@@ -201,13 +201,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
   m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
   m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);

   /*
    * From csrc/speculative
    */
   m.def(
       "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
       "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
+      "Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, "
+      "Tensor target_probs, Tensor draft_probs, "
       "float threshold_single, float threshold_acc, "
       "bool deterministic, int cuda_stream) -> ()");
   m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);

@@ -224,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
   m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);

-  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
+  m.def(
+      "segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int batch_size, "
+      "int cuda_stream) -> ()");
   m.impl("segment_packbits", torch::kCUDA, &segment_packbits);

   /*
sgl-kernel/csrc/speculative/eagle_utils.cu

@@ -32,7 +32,7 @@
 __global__ void build_tree_efficient(
     int64_t* parent_list,
     int64_t* selected_index,
-    int32_t* verified_seq_len,
+    int64_t* verified_seq_len,
     bool* tree_mask,
     int64_t* positions,
     int64_t* retrive_index,

@@ -135,7 +135,7 @@ void build_tree_kernel_efficient(
   build_tree_efficient<<<grid, block, 0, stream>>>(
       static_cast<int64_t*>(parent_list.data_ptr()),
       static_cast<int64_t*>(selected_index.data_ptr()),
-      static_cast<int32_t*>(verified_seq_len.data_ptr()),
+      static_cast<int64_t*>(verified_seq_len.data_ptr()),
       static_cast<bool*>(tree_mask.data_ptr()),
       static_cast<int64_t*>(positions.data_ptr()),
       static_cast<int64_t*>(retrive_index.data_ptr()),

@@ -146,32 +146,32 @@ void build_tree_kernel_efficient(
       int32_t(draft_token_num));
 }

-template <typename IdType>
+template <typename IdType, typename IdType2>
 __global__ void VerifyTreeGreedy(
     IdType* predicts,
     IdType* accept_index,
     IdType* accept_token_num,  // mutable
-    IdType* candidates,
-    IdType* retrive_index,
-    IdType* retrive_next_token,
-    IdType* retrive_next_sibling,
-    IdType* target_predict,
+    IdType2* candidates,
+    IdType2* retrive_index,
+    IdType2* retrive_next_token,
+    IdType2* retrive_next_sibling,
+    IdType2* target_predict,
     uint32_t batch_size,
     uint32_t num_speculative_tokens,
     uint32_t num_draft_tokens) {
   uint32_t bx = blockIdx.x;
-  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
+  IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
   accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
   uint32_t num_accepted_tokens = 0;
-  IdType cur_index = 0;
+  IdType2 cur_index = 0;

   for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
     cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
     while (cur_index != -1) {
-      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
-      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
-      IdType target_token_id = target_predict[last_accepted_retrive_idx];
+      IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
+      IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
+      IdType2 target_token_id = target_predict[last_accepted_retrive_idx];
       if (draft_token_id == target_token_id) {
         // accept token

@@ -251,35 +251,35 @@ void verify_tree_greedy(
   if (accept_token_num.scalar_type() != at::kInt) {
     throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
   }
-  if (candidates.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
+  if (candidates.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
   }
-  if (retrive_index.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
+  if (retrive_index.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
   }
-  if (retrive_next_token.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
+  if (retrive_next_token.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
   }
-  if (retrive_next_sibling.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
+  if (retrive_next_sibling.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
   }
-  if (target_predict.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32).");
+  if (target_predict.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
   }

   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   dim3 grid(batch_size);
   dim3 block(1);

-  VerifyTreeGreedy<int><<<grid, block, 0, stream>>>(
-      static_cast<int*>(predicts.data_ptr()),
-      static_cast<int*>(accept_index.data_ptr()),
-      static_cast<int*>(accept_token_num.data_ptr()),
-      static_cast<int*>(candidates.data_ptr()),
-      static_cast<int*>(retrive_index.data_ptr()),
-      static_cast<int*>(retrive_next_token.data_ptr()),
-      static_cast<int*>(retrive_next_sibling.data_ptr()),
-      static_cast<int*>(target_predict.data_ptr()),
+  VerifyTreeGreedy<int32_t, int64_t><<<grid, block, 0, stream>>>(
+      static_cast<int32_t*>(predicts.data_ptr()),
+      static_cast<int32_t*>(accept_index.data_ptr()),
+      static_cast<int32_t*>(accept_token_num.data_ptr()),
+      static_cast<int64_t*>(candidates.data_ptr()),
+      static_cast<int64_t*>(retrive_index.data_ptr()),
+      static_cast<int64_t*>(retrive_next_token.data_ptr()),
+      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
+      static_cast<int64_t*>(target_predict.data_ptr()),
       batch_size,
       num_spec_step,
       num_draft_tokens);
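
Note on the verify_tree_greedy change above: the tree-structure inputs (candidates, retrive_index, retrive_next_token, retrive_next_sibling, target_predict) are now validated as torch.int64, while the mutable output buffers (predicts, accept_index, accept_token_num) stay torch.int32. A minimal sketch of the new caller-side dtype contract, with hypothetical shapes:

import torch

# Hypothetical sizes: bs = 1, num_draft_tokens = 6, num_spec_step = 4.
# Tree-structure tensors are now int64 (at::kLong in the checks above) ...
candidates = torch.zeros(1, 6, dtype=torch.int64)
retrive_index = torch.zeros(1, 6, dtype=torch.int64)
retrive_next_token = torch.full((1, 6), -1, dtype=torch.int64)
retrive_next_sibling = torch.full((1, 6), -1, dtype=torch.int64)
target_predict = torch.zeros(1, 6, dtype=torch.int64)
# ... while the mutable outputs remain int32 (at::kInt).
predicts = torch.full((6,), -1, dtype=torch.int32)
accept_index = torch.full((1, 4), -1, dtype=torch.int32)
accept_token_num = torch.zeros(1, dtype=torch.int32)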
sgl-kernel/csrc/speculative/packbit.cu

@@ -24,7 +24,12 @@ using namespace flashinfer;
 // bitorder = "little"
 void segment_packbits(
-    at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) {
+    at::Tensor x,
+    at::Tensor input_indptr,
+    at::Tensor output_indptr,
+    at::Tensor y,
+    int64_t batch_size,
+    int64_t cuda_stream) {
   CHECK_INPUT(x);
   CHECK_INPUT(input_indptr);
   CHECK_INPUT(output_indptr);

@@ -32,8 +37,7 @@ void segment_packbits(
   CHECK_EQ(input_indptr.device(), device);
   CHECK_EQ(output_indptr.device(), device);
   CHECK_EQ(y.device(), device);
-  unsigned int batch_size = input_indptr.size(0) - 1;
-  CHECK_GE(output_indptr.size(0), batch_size + 1);
+  CHECK_EQ(output_indptr.size(0), batch_size + 1);
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   cudaError_t status = quantization::SegmentPackBits(
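
Note on the segment_packbits change above: the batch size is no longer derived inside the function from input_indptr; the caller passes it explicitly, and output_indptr must now have exactly batch_size + 1 entries (CHECK_EQ replaces the previous CHECK_GE). A small sketch of the invariant a caller has to satisfy, with toy values:

import torch

input_indptr = torch.tensor([0, 3, 8, 10])      # 3 segments of bits to pack
batch_size = input_indptr.numel() - 1           # what the function used to compute itself
output_indptr = torch.tensor([0, 1, 2, 3])      # packed-byte offsets, one byte per segment here
assert output_indptr.numel() == batch_size + 1  # mirrors the CHECK_EQ above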
sgl-kernel/csrc/speculative/speculative_sampling.cu

@@ -37,6 +37,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
     at::Tensor uniform_samples,
+    at::Tensor uniform_samples_for_final_sampling,
     at::Tensor target_probs,
     at::Tensor draft_probs,
     double threshold_single,

@@ -48,6 +49,7 @@ void tree_speculative_sampling_target_only(
   CHECK_INPUT(retrive_next_token);
   CHECK_INPUT(retrive_next_sibling);
   CHECK_INPUT(uniform_samples);
+  CHECK_INPUT(uniform_samples_for_final_sampling);
   CHECK_INPUT(target_probs);
   auto device = target_probs.device();
   CHECK_EQ(candidates.device(), device);

@@ -55,6 +57,7 @@ void tree_speculative_sampling_target_only(
   CHECK_EQ(retrive_next_token.device(), device);
   CHECK_EQ(retrive_next_sibling.device(), device);
   CHECK_EQ(uniform_samples.device(), device);
+  CHECK_EQ(uniform_samples_for_final_sampling.device(), device);
   CHECK_EQ(target_probs.device(), device);
   CHECK_DIM(1, predicts);
   CHECK_DIM(2, accept_index);

@@ -92,21 +95,24 @@ void tree_speculative_sampling_target_only(
   if (accept_token_num.scalar_type() != at::kInt) {
     throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
   }
-  if (candidates.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
+  if (candidates.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
   }
-  if (retrive_index.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
+  if (retrive_index.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
   }
-  if (retrive_next_token.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
+  if (retrive_next_token.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
   }
-  if (retrive_next_sibling.scalar_type() != at::kInt) {
-    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
+  if (retrive_next_sibling.scalar_type() != at::kLong) {
+    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
   }
   if (uniform_samples.scalar_type() != at::kFloat) {
     throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
   }
+  if (uniform_samples_for_final_sampling.scalar_type() != at::kFloat) {
+    throw std::runtime_error("Expected 'uniform_samples_for_final_sampling' to be of type float (torch.float32).");
+  }
   if (target_probs.scalar_type() != at::kFloat) {
     throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
   }

@@ -119,15 +125,16 @@ void tree_speculative_sampling_target_only(
   CHECK_GE(1, threshold_acc);

   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
-  cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
-      static_cast<int*>(predicts.data_ptr()),
-      static_cast<int*>(accept_index.data_ptr()),
-      static_cast<int*>(accept_token_num.data_ptr()),
-      static_cast<int*>(candidates.data_ptr()),
-      static_cast<int*>(retrive_index.data_ptr()),
-      static_cast<int*>(retrive_next_token.data_ptr()),
-      static_cast<int*>(retrive_next_sibling.data_ptr()),
+  cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
+      static_cast<int32_t*>(predicts.data_ptr()),
+      static_cast<int32_t*>(accept_index.data_ptr()),
+      static_cast<int32_t*>(accept_token_num.data_ptr()),
+      static_cast<int64_t*>(candidates.data_ptr()),
+      static_cast<int64_t*>(retrive_index.data_ptr()),
+      static_cast<int64_t*>(retrive_next_token.data_ptr()),
+      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
       static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<float*>(uniform_samples_for_final_sampling.data_ptr()),
       static_cast<float*>(target_probs.data_ptr()),
       static_cast<float*>(draft_probs.data_ptr()),
       batch_size,
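
Note on the new argument: the host wrapper above now takes one extra float32 tensor, uniform_samples_for_final_sampling, holding a single uniform draw per batch element; the per-token draws in uniform_samples are unchanged. A caller-side allocation sketch with hypothetical sizes, mirroring the updated test further down:

import torch

bs, num_draft_tokens = 2, 6  # hypothetical sizes
# One coin per (batch, draft token) for the acceptance loop, as before.
uniform_samples = torch.rand(bs, num_draft_tokens, dtype=torch.float32, device="cuda")
# One extra coin per batch element, used only by the final residual-sampling step.
uniform_samples_for_final_sampling = torch.rand(bs, dtype=torch.float32, device="cuda")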
sgl-kernel/csrc/speculative/speculative_sampling.cuh

@@ -34,16 +34,18 @@ template <
     uint32_t VEC_SIZE,
     bool DETERMINISTIC,
     typename DType,
-    typename IdType>
+    typename IdType,
+    typename IdType2>
 __global__ void TreeSpeculativeSamplingTargetOnly(
     IdType* predicts,  // mutable
     IdType* accept_index,  // mutable
     IdType* accept_token_num,  // mutable
-    IdType* candidates,
-    IdType* retrive_index,
-    IdType* retrive_next_token,
-    IdType* retrive_next_sibling,
+    IdType2* candidates,
+    IdType2* retrive_index,
+    IdType2* retrive_next_token,
+    IdType2* retrive_next_sibling,
     DType* uniform_samples,
+    DType* uniform_samples_for_final_sampling,
     DType* target_probs,
     DType* draft_probs,
     uint32_t batch_size,

@@ -62,16 +64,16 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
   DType prob_acc = 0.0;
   uint32_t cur_prob_offset = bx * num_draft_tokens * d;
   DType coin = uniform_samples[bx * num_draft_tokens];
-  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
+  IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
   accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
   uint32_t num_accepted_tokens = 0;
-  IdType cur_index = 0;
+  IdType2 cur_index = 0;

   for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
     cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
     while (cur_index != -1) {
-      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
-      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
+      IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
+      IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
       DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
       prob_acc += target_prob_single;

@@ -95,6 +97,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
   }
   accept_token_num[bx] = num_accepted_tokens;

+  // we need a different coin for the final sampling
+  coin = uniform_samples_for_final_sampling[bx];
+
   // sample from relu(target_probs - draft_probs)
   DType sum_relu_q_minus_p(0);
   vec_t<DType, VEC_SIZE> q_vec, p_vec;

@@ -156,16 +161,17 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
   // value at not used indices are undefined
 }

-template <typename DType, typename IdType>
+template <typename DType, typename IdType, typename IdType2>
 cudaError_t TreeSpeculativeSamplingTargetOnly(
     IdType* predicts,  // mutable
     IdType* output_token_ids,  // mutable
     IdType* output_accepted_token_num,  // mutable
-    IdType* candidates,
-    IdType* retrive_index,
-    IdType* retrive_next_token,
-    IdType* retrive_next_sibling,
+    IdType2* candidates,
+    IdType2* retrive_index,
+    IdType2* retrive_next_token,
+    IdType2* retrive_next_sibling,
     DType* uniform_samples,
+    DType* uniform_samples_for_final_sampling,
     DType* target_probs,
     DType* draft_probs,
     uint32_t batch_size,

@@ -192,6 +198,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
       &retrive_next_token,
       &retrive_next_sibling,
       &uniform_samples,
+      &uniform_samples_for_final_sampling,
       &target_probs,
       &draft_probs,
       &batch_size,

@@ -209,7 +216,8 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
         VEC_SIZE,
         DETERMINISTIC,
         DType,
-        IdType>;
+        IdType,
+        IdType2>;
     FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   })});
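
The functional fix is visible in the @@ -95,6 +97,9 @@ hunk above: the final "sample from relu(target_probs - draft_probs)" step previously reused the coin variable initialized from uniform_samples, and now draws a fresh per-batch coin from uniform_samples_for_final_sampling. A rough host-side illustration of that residual-sampling step (not the kernel itself; toy single-row probabilities):

import torch

def residual_sample(q, p, coin):
    # Draw one token id from the renormalized residual relu(q - p)
    # using a single uniform coin via an inverse-CDF lookup.
    residual = torch.relu(q - p)
    residual = residual / residual.sum()
    cdf = torch.cumsum(residual, dim=-1)
    return int(torch.searchsorted(cdf, coin))

q = torch.tensor([0.1, 0.6, 0.3])   # target probabilities (toy)
p = torch.tensor([0.3, 0.3, 0.4])   # draft probabilities (toy)
coin = torch.tensor(0.5)            # one entry of uniform_samples_for_final_sampling
print(residual_sample(q, p, coin))  # -> 1, the only token with residual mass here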
sgl-kernel/include/sgl_kernel_ops.h

@@ -331,6 +331,7 @@ void tree_speculative_sampling_target_only(
     at::Tensor retrive_next_token,
     at::Tensor retrive_next_sibling,
     at::Tensor uniform_samples,
+    at::Tensor uniform_samples_for_final_sampling,
     at::Tensor target_probs,
     at::Tensor draft_probs,
     double threshold_single = 1,

@@ -363,7 +364,12 @@ void build_tree_kernel_efficient(
     int64_t draft_token_num);

 void segment_packbits(
-    at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);
+    at::Tensor x,
+    at::Tensor input_indptr,
+    at::Tensor output_indptr,
+    at::Tensor y,
+    int64_t batch_size,
+    int64_t cuda_stream = 0);

 /*
  * From FlashInfer
sgl-kernel/python/sgl_kernel/__init__.py

@@ -72,6 +72,7 @@ from sgl_kernel.speculative import (
     tree_speculative_sampling_target_only,
     verify_tree_greedy,
 )
+from sgl_kernel.top_k import fast_topk
 from sgl_kernel.version import __version__

 build_tree_kernel = (
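
With the import added above, the new helper is also re-exported from the package root, so downstream code can simply write:

from sgl_kernel import fast_topk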
sgl-kernel/python/sgl_kernel/speculative.py

@@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only(
     retrive_next_token: torch.Tensor,
     retrive_next_sibling: torch.Tensor,
     uniform_samples: torch.Tensor,
+    uniform_samples_for_final_sampling: torch.Tensor,
     target_probs: torch.Tensor,
     draft_probs: torch.Tensor,
     threshold_single: float = 1.0,

@@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only(
         retrive_next_token,
         retrive_next_sibling,
         uniform_samples,
+        uniform_samples_for_final_sampling,
         target_probs,
         draft_probs,
         threshold_single,

@@ -91,11 +93,13 @@ def segment_packbits(
     input_indptr: torch.Tensor,
     output_indptr: torch.Tensor,
     y: torch.Tensor,
+    batch_size: int,
 ) -> None:
     torch.ops.sgl_kernel.segment_packbits.default(
         x,
         input_indptr,
         output_indptr,
         y,
+        batch_size,
         torch.cuda.current_stream().cuda_stream,
     )
sgl-kernel/python/sgl_kernel/top_k.py  (new file, 0 → 100644)

+import torch
+
+
+def fast_topk(values, topk, dim):
+    if topk == 1:
+        # Use max along the specified dimension to get both value and index
+        return torch.max(values, dim=dim, keepdim=True)
+    else:
+        # Use topk for efficiency with larger k values
+        # TODO: implement faster cuda kernels for large vocab sizes
+        return torch.topk(values, topk, dim=dim)
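
A small usage sketch for the new helper (assumes the sgl_kernel package is importable). Both branches return a (values, indices) pair with the same layout, so callers can treat k == 1 and k > 1 uniformly:

import torch
from sgl_kernel.top_k import fast_topk

logits = torch.tensor([[0.1, 2.0, 0.5, 1.5]])
v1, i1 = fast_topk(logits, 1, dim=-1)  # torch.max path: shapes (1, 1)
v3, i3 = fast_topk(logits, 3, dim=-1)  # torch.topk path: shapes (1, 3)
assert i1.item() == 1 and i3[0, 0].item() == 1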
sgl-kernel/tests/speculative/test_eagle_utils.py

@@ -10,7 +10,7 @@ def test_verify_tree_greedy():
             [0, 1, 2, 3, 4, 5],
             [7, 8, 9, 10, 11, 12],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device="cuda",
     )
     retrive_index = torch.tensor(

@@ -18,7 +18,7 @@ def test_verify_tree_greedy():
             [0, 1, 2, 3, 4, 5],
             [6, 7, 8, 9, 10, 11],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device="cuda",
     )
     retrive_next_token = torch.tensor(

@@ -26,7 +26,7 @@ def test_verify_tree_greedy():
             [1, 2, -1, 4, 5, -1],
             [4, 2, 3, -1, 5, -1],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device="cuda",
     )
     retrive_next_sibling = torch.tensor(

@@ -34,7 +34,7 @@ def test_verify_tree_greedy():
             [-1, 3, -1, -1, -1, -1],
             [-1, -1, -1, -1, 1, -1],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device="cuda",
     )

@@ -49,12 +49,11 @@ def test_verify_tree_greedy():
             if torch.max(target_logits[i][j]) < 10:
                 target_logits[i][j][18] = 10
-    target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
+    target_predict = torch.argmax(target_logits, dim=-1)
     predict_shape = (12,)

     bs = candidates.shape[0]
     num_spec_step = 4
-    num_draft_tokens = candidates.shape[1]

     predicts = torch.full(
         predict_shape, -1, dtype=torch.int32, device="cuda"
sgl-kernel/tests/speculative/test_speculative_sampling.py

@@ -42,7 +42,7 @@ def test_tree_speculative_sampling_target_only(
             [0, 1, 2, 3, 4, 5],
             [7, 8, 9, 10, 11, 12],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device=device,
     )
     retrive_index = torch.tensor(

@@ -50,7 +50,7 @@ def test_tree_speculative_sampling_target_only(
             [0, 1, 2, 3, 4, 5],
             [6, 7, 8, 9, 10, 11],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device=device,
     )
     retrive_next_token = torch.tensor(

@@ -58,7 +58,7 @@ def test_tree_speculative_sampling_target_only(
             [1, 2, -1, 4, 5, -1],
             [4, 2, 3, -1, 5, -1],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device=device,
     )
     retrive_next_sibling = torch.tensor(

@@ -66,7 +66,7 @@ def test_tree_speculative_sampling_target_only(
             [-1, 3, -1, -1, -1, -1],
             [-1, -1, -1, -1, 1, -1],
         ],
-        dtype=torch.int32,
+        dtype=torch.int64,
         device=device,
     )

@@ -95,6 +95,7 @@ def test_tree_speculative_sampling_target_only(
     target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
     draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
     coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
+    coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)

     tree_speculative_sampling_target_only(
         predicts=predicts,

@@ -105,6 +106,7 @@ def test_tree_speculative_sampling_target_only(
         retrive_next_token=retrive_next_token,
         retrive_next_sibling=retrive_next_sibling,
         uniform_samples=coins,
+        uniform_samples_for_final_sampling=coins_for_final_sampling,
         target_probs=target_probs,
         draft_probs=draft_probs,
         threshold_single=threshold_single,