Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
52a34d74
"vscode:/vscode.git/clone" did not exist on "244b66bb842dd06eaa8fddb753a53f2731a3b0a2"
Unverified
Commit
52a34d74
authored
Mar 16, 2025
by
Ying Sheng
Committed by
GitHub
Mar 16, 2025
Browse files
Add greedy verification kernel (#4383)
parent
06d12b39
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
392 additions
and
151 deletions
+392
-151
sgl-kernel/csrc/speculative/eagle_utils.cu
sgl-kernel/csrc/speculative/eagle_utils.cu
+129
-98
sgl-kernel/csrc/speculative/packbit.cu
sgl-kernel/csrc/speculative/packbit.cu
+47
-0
sgl-kernel/csrc/speculative/speculative_sampling.cu
sgl-kernel/csrc/speculative/speculative_sampling.cu
+9
-2
sgl-kernel/csrc/speculative/speculative_sampling.cuh
sgl-kernel/csrc/speculative/speculative_sampling.cuh
+13
-5
sgl-kernel/csrc/torch_extension.cc
sgl-kernel/csrc/torch_extension.cc
+11
-8
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+17
-12
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/__init__.py
+6
-1
sgl-kernel/python/sgl_kernel/speculative.py
sgl-kernel/python/sgl_kernel/speculative.py
+38
-20
sgl-kernel/setup.py
sgl-kernel/setup.py
+1
-0
sgl-kernel/tests/speculative/test_eagle_utils.py
sgl-kernel/tests/speculative/test_eagle_utils.py
+98
-0
sgl-kernel/tests/speculative/test_speculative_sampling.py
sgl-kernel/tests/speculative/test_speculative_sampling.py
+23
-5
No files found.
sgl-kernel/csrc/speculative/eagle_utils.cu
View file @
52a34d74
...
...
@@ -17,6 +17,8 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "pytorch_extension_utils.h"
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
...
...
@@ -72,8 +74,8 @@ __global__ void build_tree_efficient(
}
if
(
parent_position
==
draft_token_num
)
{
printf
(
"
ERROR
: invalid eagle tree!!! Detected a token with no parent token selected.
Check the logprob. The token
"
"
will be dropped.
"
);
"
WARNING
: invalid eagle tree!!! Detected a token with no parent token selected. "
"
Please check if the logprob has nan. The token will be ignored to keep proceeding.
\n
"
);
continue
;
}
...
...
@@ -140,112 +142,141 @@ void build_tree_kernel_efficient(
int32_t
(
draft_token_num
));
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__
void
build_tree
(
int64_t
*
parent_list
,
int64_t
*
selected_index
,
int32_t
*
verified_seq_len
,
bool
*
tree_mask
,
int64_t
*
positions
,
int64_t
*
retrive_index
,
int
topk
,
int
depth
,
int
draft_token_num
)
{
int
bid
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
if
(
tid
>=
draft_token_num
)
{
return
;
}
int
seq_tree_idx
=
draft_token_num
*
draft_token_num
*
bid
;
for
(
int
i
=
0
;
i
<
bid
;
i
++
)
{
seq_tree_idx
+=
verified_seq_len
[
i
]
*
draft_token_num
;
}
int
seq_len
=
verified_seq_len
[
bid
];
int
token_tree_idx
=
seq_tree_idx
+
(
seq_len
+
draft_token_num
)
*
tid
+
seq_len
+
1
;
for
(
int
i
=
0
;
i
<
draft_token_num
-
1
;
i
++
)
{
tree_mask
[
token_tree_idx
+
i
]
=
false
;
}
template
<
typename
IdType
>
__global__
void
VerifyTreeGreedy
(
IdType
*
predicts
,
IdType
*
accept_index
,
IdType
*
accept_token_num
,
// mutable
IdType
*
candidates
,
IdType
*
retrive_index
,
IdType
*
retrive_next_token
,
IdType
*
retrive_next_sibling
,
IdType
*
target_predict
,
uint32_t
batch_size
,
uint32_t
num_speculative_tokens
,
uint32_t
num_draft_tokens
)
{
uint32_t
bx
=
blockIdx
.
x
;
int
position
=
0
;
if
(
tid
==
0
)
{
positions
[
bid
*
draft_token_num
]
=
seq_len
;
retrive_index
[
bid
*
draft_token_num
*
(
depth
+
2
)]
=
bid
*
draft_token_num
;
return
;
}
IdType
last_accepted_retrive_idx
=
retrive_index
[
bx
*
num_draft_tokens
];
accept_index
[
bx
*
num_speculative_tokens
]
=
last_accepted_retrive_idx
;
uint32_t
num_accepted_tokens
=
0
;
IdType
cur_index
=
0
;
int
depends_order
[
10
];
for
(
uint32_t
j
=
1
;
j
<
num_speculative_tokens
;
++
j
)
{
cur_index
=
retrive_next_token
[
bx
*
num_draft_tokens
+
cur_index
];
while
(
cur_index
!=
-
1
)
{
IdType
draft_index
=
retrive_index
[
bx
*
num_draft_tokens
+
cur_index
];
IdType
draft_token_id
=
candidates
[
bx
*
num_draft_tokens
+
cur_index
];
IdType
target_token_id
=
target_predict
[
last_accepted_retrive_idx
];
int
cur_position
=
tid
-
1
;
while
(
true
)
{
depends_order
[
position
]
=
cur_position
+
1
;
position
+=
1
;
tree_mask
[
token_tree_idx
+
cur_position
]
=
true
;
int
parent_tb_idx
=
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
cur_position
]
/
topk
;
if
(
parent_tb_idx
==
0
)
{
if
(
draft_token_id
==
target_token_id
)
{
// accept token
predicts
[
last_accepted_retrive_idx
]
=
target_token_id
;
++
num_accepted_tokens
;
accept_index
[
bx
*
num_speculative_tokens
+
num_accepted_tokens
]
=
draft_index
;
last_accepted_retrive_idx
=
draft_index
;
break
;
}
else
{
cur_index
=
retrive_next_sibling
[
bx
*
num_draft_tokens
+
cur_index
];
}
}
if
(
cur_index
==
-
1
)
break
;
}
accept_token_num
[
bx
]
=
num_accepted_tokens
;
predicts
[
last_accepted_retrive_idx
]
=
target_predict
[
last_accepted_retrive_idx
];
}
int
token_idx
=
parent_list
[
bid
*
(
topk
*
(
depth
-
1
)
+
1
)
+
parent_tb_idx
];
for
(
cur_position
=
0
;
cur_position
<
draft_token_num
;
cur_position
++
)
{
if
(
selected_index
[
bid
*
(
draft_token_num
-
1
)
+
cur_position
]
==
token_idx
)
{
break
;
// predicts: [tot_num_draft_tokens]
// accept_index: [bs, num_spec_step]
// accept_token_num: [bs]
// candidates: [bs, num_draft_tokens]
// retrive_index: [bs, num_draft_tokens]
// retrive_next_token: [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// target_predict: [bs, num_draft_tokens]
void
verify_tree_greedy
(
at
::
Tensor
predicts
,
at
::
Tensor
accept_index
,
at
::
Tensor
accept_token_num
,
// mutable
at
::
Tensor
candidates
,
at
::
Tensor
retrive_index
,
at
::
Tensor
retrive_next_token
,
at
::
Tensor
retrive_next_sibling
,
at
::
Tensor
target_predict
,
int64_t
cuda_stream
=
0
)
{
CHECK_INPUT
(
candidates
);
CHECK_INPUT
(
retrive_index
);
CHECK_INPUT
(
retrive_next_token
);
CHECK_INPUT
(
retrive_next_sibling
);
CHECK_INPUT
(
target_predict
);
auto
device
=
target_predict
.
device
();
CHECK_EQ
(
candidates
.
device
(),
device
);
CHECK_EQ
(
retrive_index
.
device
(),
device
);
CHECK_EQ
(
retrive_next_token
.
device
(),
device
);
CHECK_EQ
(
retrive_next_sibling
.
device
(),
device
);
CHECK_EQ
(
target_predict
.
device
(),
device
);
CHECK_DIM
(
1
,
predicts
);
CHECK_DIM
(
2
,
accept_index
);
CHECK_DIM
(
1
,
accept_token_num
);
CHECK_DIM
(
2
,
candidates
);
CHECK_DIM
(
2
,
retrive_index
);
CHECK_DIM
(
2
,
retrive_next_token
);
CHECK_DIM
(
2
,
retrive_next_sibling
);
CHECK_DIM
(
2
,
target_predict
);
unsigned
int
batch_size
=
candidates
.
size
(
0
);
unsigned
int
num_spec_step
=
accept_index
.
size
(
1
);
unsigned
int
num_draft_tokens
=
candidates
.
size
(
1
);
CHECK_EQ
(
batch_size
,
accept_index
.
size
(
0
));
CHECK_EQ
(
batch_size
,
accept_token_num
.
size
(
0
));
CHECK_EQ
(
batch_size
,
retrive_index
.
size
(
0
));
CHECK_EQ
(
batch_size
,
retrive_next_token
.
size
(
0
));
CHECK_EQ
(
batch_size
,
retrive_next_sibling
.
size
(
0
));
CHECK_EQ
(
batch_size
,
target_predict
.
size
(
0
));
CHECK_EQ
(
num_draft_tokens
,
retrive_index
.
size
(
1
));
CHECK_EQ
(
num_draft_tokens
,
retrive_next_token
.
size
(
1
));
CHECK_EQ
(
num_draft_tokens
,
retrive_next_sibling
.
size
(
1
));
CHECK_EQ
(
num_draft_tokens
,
target_predict
.
size
(
1
));
CHECK_EQ
(
batch_size
,
accept_index
.
size
(
0
));
CHECK_EQ
(
batch_size
,
accept_token_num
.
size
(
0
));
if
(
predicts
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'predicts' to be of type int (torch.int32)."
);
}
if
(
accept_index
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'accept_index' to be of type int (torch.int32)."
);
}
if
(
cur_position
==
draft_token_num
)
{
printf
(
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
"will be dropped."
);
break
;
if
(
accept_token_num
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'accept_token_num' to be of type int (torch.int32)."
);
}
if
(
candidates
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'candidates' to be of type int (torch.int32)."
);
}
positions
[
bid
*
draft_token_num
+
tid
]
=
position
+
seq_len
;
int
is_leaf
=
0
;
for
(
int
i
=
1
;
i
<
draft_token_num
;
i
++
)
{
if
(
tree_mask
[
seq_tree_idx
+
i
*
(
draft_token_num
+
seq_len
)
+
seq_len
+
tid
])
{
is_leaf
++
;
if
(
retrive_index
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'retrive_index' to be of type int (torch.int32)."
);
}
if
(
retrive_next_token
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'retrive_next_token' to be of type int (torch.int32)."
);
}
if
(
is_leaf
==
1
)
{
for
(
int
i
=
0
;
i
<
position
;
i
++
)
{
retrive_index
[(
bid
*
(
draft_token_num
)
+
tid
)
*
(
depth
+
2
)
+
position
-
i
]
=
depends_order
[
i
]
+
bid
*
draft_token_num
;
if
(
retrive_next_sibling
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'retrive_next_sibling' to be of type int (torch.int32)."
);
}
retrive_index
[(
bid
*
(
draft_token_num
)
+
tid
)
*
(
depth
+
2
)]
=
bid
*
draft_token_num
;
if
(
target_predict
.
scalar_type
()
!=
at
::
kInt
)
{
throw
std
::
runtime_error
(
"Expected 'target_predict' to be of type int (torch.int32)."
);
}
}
void
build_tree_kernel
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
)
{
// TODO (ying) check shape
// TODO (ying) check type
int
bs
=
parent_list
.
size
(
0
);
dim3
grid
(
bs
);
dim3
block
(
draft_token_num
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
cudaStream_t
stream
=
reinterpret_cast
<
cudaStream_t
>
(
cuda_stream
);
dim3
grid
(
batch_size
);
dim3
block
(
1
);
build_tree
<<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
int64_t
*>
(
parent_list
.
data_ptr
()),
static_cast
<
int64_t
*>
(
selected_index
.
data_ptr
()),
static_cast
<
int32_t
*>
(
verified_seq_len
.
data_ptr
()),
static_cast
<
bool
*>
(
tree_mask
.
data_ptr
()),
static_cast
<
int64_t
*>
(
positions
.
data_ptr
()),
static_cast
<
int64_t
*>
(
retrive_index
.
data_ptr
()),
int32_t
(
topk
),
int32_t
(
depth
),
int32_t
(
draft_token_num
));
VerifyTreeGreedy
<
int
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
int
*>
(
predicts
.
data_ptr
()),
static_cast
<
int
*>
(
accept_index
.
data_ptr
()),
static_cast
<
int
*>
(
accept_token_num
.
data_ptr
()),
static_cast
<
int
*>
(
candidates
.
data_ptr
()),
static_cast
<
int
*>
(
retrive_index
.
data_ptr
()),
static_cast
<
int
*>
(
retrive_next_token
.
data_ptr
()),
static_cast
<
int
*>
(
retrive_next_sibling
.
data_ptr
()),
static_cast
<
int
*>
(
target_predict
.
data_ptr
()),
batch_size
,
num_spec_step
,
num_draft_tokens
);
}
sgl-kernel/csrc/speculative/packbit.cu
0 → 100644
View file @
52a34d74
// This is only a pluggin used for flashinfer 0.1.6. The new version does not need it.
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/quantization.cuh>
#include "pytorch_extension_utils.h"
using
namespace
flashinfer
;
// bitorder = "little"
void
segment_packbits
(
at
::
Tensor
x
,
at
::
Tensor
input_indptr
,
at
::
Tensor
output_indptr
,
at
::
Tensor
y
,
int64_t
cuda_stream
)
{
CHECK_INPUT
(
x
);
CHECK_INPUT
(
input_indptr
);
CHECK_INPUT
(
output_indptr
);
auto
device
=
x
.
device
();
CHECK_EQ
(
input_indptr
.
device
(),
device
);
CHECK_EQ
(
output_indptr
.
device
(),
device
);
CHECK_EQ
(
y
.
device
(),
device
);
unsigned
int
batch_size
=
input_indptr
.
size
(
0
)
-
1
;
CHECK_EQ
(
output_indptr
.
size
(
0
),
batch_size
+
1
);
cudaStream_t
stream
=
reinterpret_cast
<
cudaStream_t
>
(
cuda_stream
);
cudaError_t
status
=
quantization
::
SegmentPackBits
(
static_cast
<
bool
*>
(
x
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
y
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_indptr
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_indptr
.
data_ptr
()),
batch_size
,
quantization
::
BitOrder
::
kLittle
,
stream
);
}
sgl-kernel/csrc/speculative/speculative_sampling.cu
View file @
52a34d74
...
...
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"
#include "speculative_sampling.cuh"
...
...
@@ -40,7 +39,9 @@ void tree_speculative_sampling_target_only(
at
::
Tensor
uniform_samples
,
at
::
Tensor
target_probs
,
at
::
Tensor
draft_probs
,
bool
deterministic
,
double
threshold_single
,
double
threshold_acc
,
bool
deterministic
=
true
,
int64_t
cuda_stream
=
0
)
{
CHECK_INPUT
(
candidates
);
CHECK_INPUT
(
retrive_index
);
...
...
@@ -112,6 +113,10 @@ void tree_speculative_sampling_target_only(
if
(
draft_probs
.
scalar_type
()
!=
at
::
kFloat
)
{
throw
std
::
runtime_error
(
"Expected 'target_probs' to be of type float (torch.float32)."
);
}
CHECK_GE
(
threshold_single
,
0
);
CHECK_GE
(
1
,
threshold_single
);
CHECK_GE
(
threshold_acc
,
0
);
CHECK_GE
(
1
,
threshold_acc
);
cudaStream_t
stream
=
reinterpret_cast
<
cudaStream_t
>
(
cuda_stream
);
cudaError_t
status
=
sampling
::
TreeSpeculativeSamplingTargetOnly
<
float
,
int
>
(
...
...
@@ -129,6 +134,8 @@ void tree_speculative_sampling_target_only(
num_spec_step
,
num_draft_tokens
,
vocab_size
,
static_cast
<
float
>
(
threshold_single
),
static_cast
<
float
>
(
threshold_acc
),
deterministic
,
stream
);
...
...
sgl-kernel/csrc/speculative/speculative_sampling.cuh
View file @
52a34d74
...
...
@@ -49,7 +49,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
uint32_t
batch_size
,
uint32_t
num_speculative_tokens
,
uint32_t
num_draft_tokens
,
uint32_t
d
)
{
uint32_t
d
,
DType
threshold_single
,
DType
threshold_acc
)
{
const
uint32_t
bx
=
blockIdx
.
x
,
tx
=
threadIdx
.
x
;
extern
__shared__
__align__
(
alignof
(
SamplingTempStorage
<
DType
,
BLOCK_THREADS
,
SCAN_ALGORITHM
,
REDUCE_ALGORITHM
>
))
...
...
@@ -70,9 +72,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
while
(
cur_index
!=
-
1
)
{
IdType
draft_index
=
retrive_index
[
bx
*
num_draft_tokens
+
cur_index
];
IdType
draft_token_id
=
candidates
[
bx
*
num_draft_tokens
+
cur_index
];
prob_acc
+=
target_probs
[
cur_prob_offset
+
draft_token_id
];
DType
target_prob_single
=
target_probs
[
cur_prob_offset
+
draft_token_id
];
prob_acc
+=
target_prob_single
;
if
(
coin
<
prob_acc
)
{
if
(
coin
<
=
prob_acc
/
threshold_acc
||
target_prob_single
>=
threshold_single
)
{
// accept token
prob_acc
=
0.
;
cur_prob_offset
=
(
bx
*
num_draft_tokens
+
cur_index
)
*
d
;
...
...
@@ -169,7 +172,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
uint32_t
num_speculative_tokens
,
uint32_t
num_draft_tokens
,
uint32_t
d
,
bool
deterministic
,
DType
threshold_single
=
1
,
DType
threshold_acc
=
1
,
bool
deterministic
=
true
,
cudaStream_t
stream
=
0
)
{
constexpr
uint32_t
BLOCK_THREADS
=
1024
;
const
uint32_t
vec_size
=
std
::
gcd
(
16
/
sizeof
(
DType
),
d
);
...
...
@@ -177,6 +182,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
const
uint32_t
smem_size
=
sizeof
(
SamplingTempStorage
<
DType
,
BLOCK_THREADS
,
SCAN_ALGO
,
REDUCE_ALGO
>
);
dim3
nblks
(
batch_size
);
dim3
nthrs
(
BLOCK_THREADS
);
float
capped_threshold_acc
=
fmaxf
(
threshold_acc
,
1e-9
f
);
void
*
args
[]
=
{
&
predicts
,
&
output_token_ids
,
...
...
@@ -191,7 +197,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
&
batch_size
,
&
num_speculative_tokens
,
&
num_draft_tokens
,
&
d
};
&
d
,
&
threshold_single
,
&
capped_threshold_acc
};
DISPATCH_ALIGNED_VEC_SIZE
(
vec_size
,
VEC_SIZE
,
{
DISPATCH_DETERMINISTIC
(
deterministic
,
DETERMINISTIC
,
{
auto
kernel
=
TreeSpeculativeSamplingTargetOnly
<
...
...
sgl-kernel/csrc/torch_extension.cc
View file @
52a34d74
...
...
@@ -129,21 +129,24 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
"float threshold_single, float threshold_acc, "
"bool deterministic, int cuda_stream) -> ()"
);
m
.
impl
(
"tree_speculative_sampling_target_only"
,
torch
::
kCUDA
,
&
tree_speculative_sampling_target_only
);
m
.
def
(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()"
);
m
.
impl
(
"verify_tree_greedy"
,
torch
::
kCUDA
,
&
verify_tree_greedy
);
m
.
def
(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
"retrive_next_sibling, "
"int topk, int depth, int draft_token_num) -> ()"
);
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"
);
m
.
impl
(
"build_tree_kernel_efficient"
,
torch
::
kCUDA
,
&
build_tree_kernel_efficient
);
m
.
def
(
"build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
"int topk, int depth, int draft_token_num) -> ()"
);
m
.
impl
(
"build_tree_kernel"
,
torch
::
kCUDA
,
&
build_tree_kernel
);
m
.
def
(
"segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()"
);
m
.
impl
(
"segment_packbits"
,
torch
::
kCUDA
,
&
segment_packbits
);
/*
* From FlashInfer
...
...
sgl-kernel/include/sgl_kernel_ops.h
View file @
52a34d74
...
...
@@ -183,8 +183,8 @@ void topk_softmax(
* From csrc/speculative
*/
void
tree_speculative_sampling_target_only
(
at
::
Tensor
predicts
,
at
::
Tensor
accept_index
,
at
::
Tensor
predicts
,
// mutable
at
::
Tensor
accept_index
,
// mutable
at
::
Tensor
accept_token_num
,
// mutable
at
::
Tensor
candidates
,
at
::
Tensor
retrive_index
,
...
...
@@ -193,33 +193,38 @@ void tree_speculative_sampling_target_only(
at
::
Tensor
uniform_samples
,
at
::
Tensor
target_probs
,
at
::
Tensor
draft_probs
,
double
threshold_single
=
1
,
double
threshold_acc
=
1
,
bool
deterministic
=
true
,
int64_t
cuda_stream
=
0
);
void
build_tree_kernel_efficient
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
void
verify_tree_greedy
(
at
::
Tensor
predicts
,
// mutable
at
::
Tensor
accept_index
,
// mutable
at
::
Tensor
accept_token_num
,
// mutable
at
::
Tensor
candidates
,
at
::
Tensor
retrive_index
,
at
::
Tensor
retrive_next_token
,
at
::
Tensor
retrive_next_sibling
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
);
at
::
Tensor
target_predict
,
int64_t
cuda_stream
=
0
);
void
build_tree_kernel
(
void
build_tree_kernel
_efficient
(
at
::
Tensor
parent_list
,
at
::
Tensor
selected_index
,
at
::
Tensor
verified_seq_len
,
at
::
Tensor
tree_mask
,
at
::
Tensor
positions
,
at
::
Tensor
retrive_index
,
at
::
Tensor
retrive_next_token
,
at
::
Tensor
retrive_next_sibling
,
int64_t
topk
,
int64_t
depth
,
int64_t
draft_token_num
);
void
segment_packbits
(
at
::
Tensor
x
,
at
::
Tensor
input_indptr
,
at
::
Tensor
output_indptr
,
at
::
Tensor
y
,
int64_t
cuda_stream
);
/*
* From FlashInfer
*/
...
...
sgl-kernel/python/sgl_kernel/__init__.py
View file @
52a34d74
...
...
@@ -42,8 +42,13 @@ from sgl_kernel.sampling import (
top_p_sampling_from_probs
,
)
from
sgl_kernel.speculative
import
(
build_tree_kernel
,
build_tree_kernel_efficient
,
segment_packbits
,
tree_speculative_sampling_target_only
,
verify_tree_greedy
,
)
from
sgl_kernel.version
import
__version__
build_tree_kernel
=
(
None
# TODO(ying): remove this after updating the sglang python code.
)
sgl-kernel/python/sgl_kernel/speculative.py
View file @
52a34d74
...
...
@@ -13,6 +13,8 @@ def tree_speculative_sampling_target_only(
uniform_samples
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
threshold_single
:
float
=
1.0
,
threshold_acc
:
float
=
1.0
,
deterministic
:
bool
=
True
,
)
->
None
:
torch
.
ops
.
sgl_kernel
.
tree_speculative_sampling_target_only
(
...
...
@@ -26,58 +28,74 @@ def tree_speculative_sampling_target_only(
uniform_samples
,
target_probs
,
draft_probs
,
threshold_single
,
threshold_acc
,
deterministic
,
get_cuda_stream
(),
)
def
build_tree_kernel_efficient
(
parent_list
:
torch
.
Tensor
,
selected_index
:
torch
.
Tensor
,
verified_seq_len
:
torch
.
Tensor
,
tree_mask
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
verify_tree_greedy
(
predicts
:
torch
.
Tensor
,
# mutable
accept_index
:
torch
.
Tensor
,
# mutable
accept_token_num
:
torch
.
Tensor
,
# mutable
candidates
:
torch
.
Tensor
,
retrive_index
:
torch
.
Tensor
,
retrive_next_token
:
torch
.
Tensor
,
retrive_next_sibling
:
torch
.
Tensor
,
topk
:
int
,
depth
:
int
,
draft_token_num
:
int
,
target_predict
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
sgl_kernel
.
build_tree_kernel_efficient
(
parent_list
,
selected_index
,
verified_seq_len
,
tree_mask
,
positions
,
torch
.
ops
.
sgl_kernel
.
verify_tree_greedy
(
predicts
,
accept_index
,
accept_token_num
,
candidates
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
topk
,
depth
,
draft_token_num
,
target_predict
,
get_cuda_stream
(),
)
def
build_tree_kernel
(
def
build_tree_kernel
_efficient
(
parent_list
:
torch
.
Tensor
,
selected_index
:
torch
.
Tensor
,
verified_seq_len
:
torch
.
Tensor
,
tree_mask
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
retrive_index
:
torch
.
Tensor
,
retrive_next_token
:
torch
.
Tensor
,
retrive_next_sibling
:
torch
.
Tensor
,
topk
:
int
,
depth
:
int
,
draft_token_num
:
int
,
)
->
None
:
torch
.
ops
.
sgl_kernel
.
build_tree_kernel
(
torch
.
ops
.
sgl_kernel
.
build_tree_kernel
_efficient
(
parent_list
,
selected_index
,
verified_seq_len
,
tree_mask
,
positions
,
retrive_index
,
retrive_next_token
,
retrive_next_sibling
,
topk
,
depth
,
draft_token_num
,
)
def
segment_packbits
(
x
:
torch
.
Tensor
,
input_indptr
:
torch
.
Tensor
,
output_indptr
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
sgl_kernel
.
segment_packbits
(
x
,
input_indptr
,
output_indptr
,
y
,
torch
.
cuda
.
current_stream
().
cuda_stream
,
)
sgl-kernel/setup.py
View file @
52a34d74
...
...
@@ -209,6 +209,7 @@ sources = [
"csrc/moe/moe_topk_softmax_kernels.cu"
,
"csrc/speculative/eagle_utils.cu"
,
"csrc/speculative/speculative_sampling.cu"
,
"csrc/speculative/packbit.cu"
,
"csrc/torch_extension.cc"
,
"3rdparty/flashinfer/csrc/norm.cu"
,
"3rdparty/flashinfer/csrc/renorm.cu"
,
...
...
sgl-kernel/tests/speculative/test_eagle_utils.py
0 → 100644
View file @
52a34d74
import
torch
import
torch.nn.functional
as
F
from
sgl_kernel
import
verify_tree_greedy
def
test_verify_tree_greedy
():
candidates
=
torch
.
tensor
(
[
[
0
,
1
,
2
,
3
,
4
,
5
],
[
7
,
8
,
9
,
10
,
11
,
12
],
],
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
retrive_index
=
torch
.
tensor
(
[
[
0
,
1
,
2
,
3
,
4
,
5
],
[
6
,
7
,
8
,
9
,
10
,
11
],
],
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
retrive_next_token
=
torch
.
tensor
(
[
[
1
,
2
,
-
1
,
4
,
5
,
-
1
],
[
4
,
2
,
3
,
-
1
,
5
,
-
1
],
],
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
retrive_next_sibling
=
torch
.
tensor
(
[
[
-
1
,
3
,
-
1
,
-
1
,
-
1
,
-
1
],
[
-
1
,
-
1
,
-
1
,
-
1
,
1
,
-
1
],
],
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
target_logits
=
torch
.
full
((
2
,
6
,
20
),
1
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
target_logits
[
0
,
0
,
3
]
=
10
target_logits
[
0
,
3
,
4
]
=
10
target_logits
[
0
,
4
,
5
]
=
10
target_logits
[
1
,
0
,
11
]
=
10
target_logits
[
1
,
4
,
12
]
=
10
for
i
in
range
(
target_logits
.
shape
[
0
]):
for
j
in
range
(
target_logits
.
shape
[
1
]):
if
torch
.
max
(
target_logits
[
i
][
j
])
<
10
:
target_logits
[
i
][
j
][
18
]
=
10
print
(
f
"
{
target_logits
=
}
"
)
target_predict
=
torch
.
argmax
(
target_logits
,
dim
=-
1
).
to
(
torch
.
int32
)
predict_shape
=
(
12
,)
bs
=
candidates
.
shape
[
0
]
num_spec_step
=
4
num_draft_tokens
=
candidates
.
shape
[
1
]
predicts
=
torch
.
full
(
predict_shape
,
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# mutable
accept_index
=
torch
.
full
(
(
bs
,
num_spec_step
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# mutable
accept_token_num
=
torch
.
full
((
bs
,),
0
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# mutable
print
(
f
"
{
candidates
=
}
"
)
print
(
f
"
{
retrive_index
=
}
"
)
print
(
f
"
{
retrive_next_token
=
}
"
)
print
(
f
"
{
retrive_next_sibling
=
}
"
)
print
(
f
"
{
target_predict
=
}
"
)
verify_tree_greedy
(
predicts
=
predicts
,
accept_index
=
accept_index
,
accept_token_num
=
accept_token_num
,
candidates
=
candidates
,
retrive_index
=
retrive_index
,
retrive_next_token
=
retrive_next_token
,
retrive_next_sibling
=
retrive_next_sibling
,
target_predict
=
target_predict
,
)
print
(
f
"
{
predicts
=
}
"
)
print
(
f
"
{
accept_index
=
}
"
)
print
(
f
"
{
accept_token_num
=
}
"
)
return
predicts
,
accept_index
,
accept_token_num
if
__name__
==
"__main__"
:
predicts
,
accept_index
,
accept_token_num
=
test_verify_tree_greedy
()
assert
predicts
.
tolist
()
==
[
3
,
-
1
,
-
1
,
4
,
5
,
18
,
11
,
-
1
,
-
1
,
-
1
,
12
,
18
]
assert
accept_index
.
tolist
()
==
[
[
0
,
3
,
4
,
5
],
[
6
,
10
,
11
,
-
1
],
]
assert
accept_token_num
.
tolist
()
==
[
3
,
2
]
sgl-kernel/tests/test_speculative_sampling.py
→
sgl-kernel/tests/
speculative/
test_speculative_sampling.py
View file @
52a34d74
...
...
@@ -3,7 +3,10 @@ import torch.nn.functional as F
from
sgl_kernel
import
tree_speculative_sampling_target_only
def
test_tree_speculative_sampling_target_only
():
def
test_tree_speculative_sampling_target_only
(
threshold_single
=
1
,
threshold_acc
=
1
):
print
(
f
"
\n
============= run test:
{
threshold_single
=
}
{
threshold_acc
=
}
==============
\n
"
)
candidates
=
torch
.
tensor
(
[
[
0
,
1
,
2
,
3
,
4
,
5
],
...
...
@@ -37,7 +40,7 @@ def test_tree_speculative_sampling_target_only():
device
=
"cuda"
,
)
target_logits
=
torch
.
zeros
((
2
,
6
,
20
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
target_logits
=
torch
.
full
((
2
,
6
,
20
),
1
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
target_logits
[
0
,
0
,
3
]
=
10
target_logits
[
0
,
3
,
4
]
=
10
target_logits
[
0
,
4
,
5
]
=
10
...
...
@@ -85,6 +88,8 @@ def test_tree_speculative_sampling_target_only():
uniform_samples
=
coins
,
target_probs
=
target_probs
,
draft_probs
=
draft_probs
,
threshold_single
=
threshold_single
,
threshold_acc
=
threshold_acc
,
deterministic
=
True
,
)
...
...
@@ -92,6 +97,13 @@ def test_tree_speculative_sampling_target_only():
print
(
f
"
{
accept_index
=
}
"
)
print
(
f
"
{
accept_token_num
=
}
"
)
return
predicts
,
accept_index
,
accept_token_num
if
__name__
==
"__main__"
:
predicts
,
accept_index
,
accept_token_num
=
(
test_tree_speculative_sampling_target_only
(
threshold_single
=
1
,
threshold_acc
=
1
)
)
assert
predicts
.
tolist
()
==
[
3
,
-
1
,
-
1
,
4
,
5
,
18
,
11
,
-
1
,
-
1
,
-
1
,
12
,
18
]
assert
accept_index
.
tolist
()
==
[
[
0
,
3
,
4
,
5
],
...
...
@@ -99,6 +111,12 @@ def test_tree_speculative_sampling_target_only():
]
assert
accept_token_num
.
tolist
()
==
[
3
,
2
]
if
__name__
==
"__main__"
:
test_tree_speculative_sampling_target_only
()
predicts
,
accept_index
,
accept_token_num
=
(
test_tree_speculative_sampling_target_only
(
threshold_single
=
0
,
threshold_acc
=
0
)
)
assert
predicts
.
tolist
()
==
[
1
,
2
,
18
,
-
1
,
-
1
,
-
1
,
11
,
-
1
,
-
1
,
-
1
,
12
,
18
]
assert
accept_index
.
tolist
()
==
[
[
0
,
1
,
2
,
-
1
],
[
6
,
10
,
11
,
-
1
],
]
assert
accept_token_num
.
tolist
()
==
[
2
,
2
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment