change / sglang · Commits · commit c08a717c

[Feat] Update sgl-kernel flashinfer to latest main version (#5500)

Unverified commit c08a717c, authored Apr 18, 2025 by PGFLMG, committed by GitHub on Apr 17, 2025.
Co-authored-by: zhyncs <me@zhyncs.com>
Parent: f13d65a7
Showing 8 changed files with 393 additions and 133 deletions (+393 / -133).
  sgl-kernel/CMakeLists.txt                                  +2    -2
  sgl-kernel/csrc/common_extension.cc                        +12   -17
  sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu   +5    -1
  sgl-kernel/csrc/speculative/speculative_sampling.cuh       +4    -4
  sgl-kernel/include/sgl_kernel_ops.h                        +16   -26
  sgl-kernel/python/sgl_kernel/elementwise.py                +108  -8
  sgl-kernel/python/sgl_kernel/sampling.py                   +213  -38
  sgl-kernel/tests/test_sampling.py                          +33   -37
sgl-kernel/CMakeLists.txt  (view file @ c08a717c)

@@ -58,8 +58,8 @@ FetchContent_Populate(repo-deepgemm)
 # flashinfer
 FetchContent_Declare(
   repo-flashinfer
-  GIT_REPOSITORY https://github.com/sgl-project/flashinfer
-  GIT_TAG sgl-kernel
+  GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
+  GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flashinfer)
sgl-kernel/csrc/common_extension.cc  (view file @ c08a717c)

@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("rmsnorm", torch::kCUDA, &rmsnorm);

-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);

-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);

-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);

   m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");

@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);

   m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
+      "min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
+      "min_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);

-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
   m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);

-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
   m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);

   m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
+      "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);

   m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
+      "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
+      "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

   /*
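The schemas above are what the Python wrappers later in this commit invoke through torch.ops. A minimal sketch of how the new trailing arguments surface at the call site (my own illustration, assuming the sgl_kernel extension is installed and importable; tensor shapes and dtypes are arbitrary):

    import torch
    import sgl_kernel  # importing the package registers the sgl_kernel custom ops

    x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    w = torch.ones(1024, device="cuda", dtype=torch.float16)
    out = torch.empty_like(x)

    # The old schema ended in `int cuda_stream`; the new one ends in `bool enable_pdl`,
    # so the raw op now takes a boolean flag instead of a stream handle.
    torch.ops.sgl_kernel.rmsnorm.default(out, x, w, 1e-6, False)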
sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu  (view file @ c08a717c)

@@ -21,7 +21,8 @@ limitations under the License.

 using namespace flashinfer;

-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
   CHECK_INPUT(input);
   CHECK_INPUT(residual);
   CHECK_INPUT(weight);

@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
         static_cast<c_type*>(weight.data_ptr()),
         batch_size,
         hidden_size,
+        input.stride(0),
+        residual.stride(0),
         eps,
+        enable_pdl,
         torch_current_stream);
     TORCH_CHECK(
         status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
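For readers checking the kernel semantics, the Python docstrings added later in this commit describe the fused op as `residual += input` followed by `input = residual / RMS(residual) * weight`. A pure-PyTorch reference of that computation (my own sketch, not part of the commit), useful for comparing against the kernel output:

    import torch

    def fused_add_rmsnorm_ref(x, residual, weight, eps=1e-6):
        # Step 1: accumulate the input into the residual.
        residual = residual + x
        # Step 2: RMS-normalize the accumulated residual and scale by the weight.
        rms = torch.sqrt(residual.pow(2).mean(dim=-1, keepdim=True) + eps)
        return residual / rms * weight, residual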
sgl-kernel/csrc/speculative/speculative_sampling.cuh  (view file @ c08a717c)

@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
     DType threshold_acc) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;

-  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+  extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
       uint8_t smem_sampling[];
   auto& temp_storage =
-      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

   DType prob_acc = 0.0;
   uint32_t cur_prob_offset = bx * num_draft_tokens * d;

@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
       relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
     }

-    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
         i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
     if (aggregate_relu_q_minus_p > u) {
       break;

@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
   constexpr uint32_t BLOCK_THREADS = 1024;
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

-  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
   float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
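The only change in this file is dropping the DType template argument from SamplingTempStorage and DeviceSamplingFromProb to match the updated flashinfer headers; the sampling logic itself is unchanged. As a rough orientation for what that logic computes, here is a hedged PyTorch sketch of the relu(q - p) residual-sampling step visible in this hunk (my own reference, not the kernel's actual code path, which uses a block-wide scan rather than torch.multinomial):

    import torch
    from typing import Optional

    def residual_sample_ref(q: torch.Tensor, p: torch.Tensor,
                            gen: Optional[torch.Generator] = None) -> torch.Tensor:
        # When a draft token is rejected, the replacement token is drawn from the
        # normalized positive part of (target_probs - draft_probs).
        relu_q_minus_p = torch.clamp(q - p, min=0.0)
        residual = relu_q_minus_p / relu_q_minus_p.sum(dim=-1, keepdim=True)
        return torch.multinomial(residual, num_samples=1, generator=gen).squeeze(-1)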
sgl-kernel/include/sgl_kernel_ops.h  (view file @ c08a717c)

@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
 /*
  * From csrc/elementwise
  */
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void gemma_fused_add_rmsnorm(
-    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
 void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

@@ -254,48 +254,38 @@ void segment_packbits(
  */
 void min_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_min_p_arr,
     double min_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);

 void top_k_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_k_arr,
-    int64_t top_k_val,
-    int64_t cuda_stream);
+    int64_t top_k_val);

 void top_p_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_p_arr,
-    double top_p_val,
-    int64_t cuda_stream);
+    double top_p_val);

 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_k_arr,
     double top_k_val,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);

 void top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);

 namespace flash {
 /*
sgl-kernel/python/sgl_kernel/elementwise.py  (view file @ c08a717c)

@@ -11,17 +11,69 @@ def rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor, if specified, the kernel will update this tensor inplace.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out


 def fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
+    r"""Fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(
+        input, residual, weight, eps, enable_pdl
+    )


 def gemma_rmsnorm(

@@ -29,20 +81,68 @@ def gemma_rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Gemma-style root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor, if specified, the kernel will update this tensor inplace.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm.default(
-        out, input, weight, eps, get_cuda_stream()
-    )
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out


 def gemma_fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
+    r"""Gemma-style fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
-        input, residual, weight, eps, get_cuda_stream()
+        input, residual, weight, eps, enable_pdl
     )
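A short usage sketch of the updated wrappers (my own example, assuming these functions are re-exported at the sgl_kernel package top level, as the test file does for the sampling ops; shapes are arbitrary). enable_pdl defaults to False, and programmatic dependent launch is only useful on GPUs that support it (Hopper-class or newer), so passing False is always safe:

    import torch
    from sgl_kernel import rmsnorm, fused_add_rmsnorm

    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)
    w = torch.ones(4096, device="cuda", dtype=torch.float16)

    y = rmsnorm(x, w, eps=1e-6, enable_pdl=False)         # returns a new tensor
    fused_add_rmsnorm(x, residual, w, enable_pdl=False)   # updates x and residual in place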
sgl-kernel/python/sgl_kernel/sampling.py  (view file @ c08a717c)

@@ -13,11 +13,7 @@ def _top_k_renorm_probs_internal(
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
     torch.ops.sgl_kernel.top_k_renorm_probs.default(
-        probs,
-        renorm_probs,
-        maybe_top_k_arr,
-        top_k_val,
-        get_cuda_stream(),
+        probs, renorm_probs, maybe_top_k_arr, top_k_val
     )
     return renorm_probs

@@ -26,6 +22,30 @@ def top_k_renorm_probs(
     probs: torch.Tensor,
     top_k: Union[torch.Tensor, int],
 ) -> torch.Tensor:
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for renormalizing probabilities by top-k thresholding.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
+        for re-normalizing probabilities, should be in ``(0, num_classes)``.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
+
+    Returns
+    -------
+    renorm_probs: torch.Tensor
+        Renormalized probabilities, shape ``(batch_size, num_classes)``.
+
+    Note
+    ----
+    This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
+    ``top_k_sampling_from_probs``.
+    """
     return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))

@@ -41,11 +61,7 @@ def _top_p_renorm_probs_internal(
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
     torch.ops.sgl_kernel.top_p_renorm_probs.default(
-        probs,
-        renorm_probs,
-        maybe_top_p_arr,
-        top_p_val,
-        get_cuda_stream(),
+        probs, renorm_probs, maybe_top_p_arr, top_p_val
     )
     return renorm_probs

@@ -54,6 +70,32 @@ def top_p_renorm_probs(
     probs: torch.Tensor,
     top_p: Union[torch.Tensor, float],
 ) -> torch.Tensor:
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for renormalizing probabilities by top-p thresholding.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities, shape ``(batch_size, num_classes)``.
+    top_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold
+        for re-normalizing probabilities, should be in ``(0, 1)``.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+        We mask out the probabilities less than `threshold` where the cumulative sum
+        of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.
+
+    Returns
+    -------
+    renorm_probs: torch.Tensor
+        Renormalized probabilities, shape ``(batch_size, num_classes)``.
+
+    Note
+    ----
+    This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
+    ``top_p_sampling_from_probs``.
+    """
     return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))

@@ -62,93 +104,187 @@ top_p_renorm_prob = top_p_renorm_probs
 def _top_p_sampling_from_probs_internal(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
+    indices: Optional[torch.Tensor],
     maybe_top_p_arr: Optional[torch.Tensor],
     top_p_val: float,
     deterministic: bool,
+    generator: Optional[torch.Generator],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with probs.device as device:
         probs = probs.float()
-        uniform_samples = uniform_samples.float()
         maybe_top_p_arr = (
             maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
         )
         samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
         torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
             probs,
-            uniform_samples,
             samples,
-            success,
+            indices,
             maybe_top_p_arr,
             top_p_val,
             deterministic,
-            get_cuda_stream(),
+            generator,
         )
-        return samples, success
+        return samples


 def top_p_sampling_from_probs(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
     top_p: Union[torch.Tensor, float],
+    indices: Optional[torch.Tensor] = None,
     deterministic: bool = True,
+    generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
+    this operator implements GPU-based rejection sampling without explicit sorting.
+    Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
+        and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
+        shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
+        probability distributions.
+    top_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    indices: Optional[torch.Tensor]
+        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
+        For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
+        This allows reusing the same probability distribution for multiple outputs.
+        If indices is not provided, the i-th output will be sampled from the i-th row of probs.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
+    generator: Optional[torch.Generator]
+        A random number generator for the operation.
+    check_nan: bool
+        Whether to check nan in :attr:`probs`, default is ``False``.
+
+    Returns
+    -------
+    samples: torch.Tensor
+        Sampled categories, shape ``(batch_size,)``.
+
+    Note
+    ----
+    This function expects float32 inputs, and the output is int32.
+    """
     if check_nan:
         if torch.any(torch.isnan(probs)):
             raise ValueError("Input probs contains NaN.")
     return _top_p_sampling_from_probs_internal(
-        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
+        probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
     )


 def _top_k_top_p_sampling_from_probs_internal(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
+    indices: Optional[torch.Tensor],
     maybe_top_k_arr: Optional[torch.Tensor],
     top_k_val: int,
     maybe_top_p_arr: Optional[torch.Tensor],
     top_p_val: float,
     deterministic: bool,
+    generator: Optional[torch.Generator],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with probs.device as device:
         probs = probs.float()
-        uniform_samples = uniform_samples.float()
         maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
         maybe_top_p_arr = (
             maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
         )
         samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
         torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
             probs,
-            uniform_samples,
             samples,
-            success,
+            indices,
             maybe_top_k_arr,
             top_k_val,
             maybe_top_p_arr,
             top_p_val,
             deterministic,
-            get_cuda_stream(),
+            generator,
         )
-        return samples, success
+        return samples


 def top_k_top_p_sampling_from_probs(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
     top_k: Union[torch.Tensor, int],
     top_p: Union[torch.Tensor, float],
+    indices: Optional[torch.Tensor] = None,
     filter_apply_order: str = "top_k_first",
     deterministic: bool = True,
+    generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for top-k and top-p sampling from probabilities,
+    this operator implements GPU-based rejection sampling without explicit sorting.
+    Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
+        and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
+        shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
+        probability distributions.
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    top_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    indices: Optional[torch.Tensor]
+        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
+        For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
+        This allows reusing the same probability distribution for multiple outputs.
+        If indices is not provided, the i-th output will be sampled from the i-th row of probs.
+    filter_apply_order: str
+        The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
+        If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
+        If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
+    generator: Optional[torch.Generator]
+        A random number generator for the operation.
+    check_nan: bool
+        Whether to check nan in :attr:`probs`, default is ``False``.
+
+    Returns
+    -------
+    samples: torch.Tensor
+        Sampled categories, shape ``(batch_size,)``.
+
+    Note
+    ----
+    This function expects float32 inputs, and the output is int32.
+    """
     if filter_apply_order == "top_k_first":
         renorm_probs = top_k_renorm_probs(probs, top_k)
         return top_p_sampling_from_probs(
-            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
+            renorm_probs,
+            top_p,
+            indices,
+            deterministic,
+            check_nan=check_nan,
+            generator=generator,
         )
     elif filter_apply_order == "joint":
         if check_nan:

@@ -156,10 +292,11 @@ def top_k_top_p_sampling_from_probs(
             raise ValueError("Input probs contains NaN.")
         return _top_k_top_p_sampling_from_probs_internal(
             probs,
-            uniform_samples,
+            indices,
             *_to_tensor_scalar_tuple(top_k),
             *_to_tensor_scalar_tuple(top_p),
             deterministic,
+            generator,
         )
     else:
         raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")

@@ -167,44 +304,82 @@ def top_k_top_p_sampling_from_probs(
 def _min_p_sampling_from_probs_internal(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
+    indices: Optional[torch.Tensor],
     maybe_min_p_arr: Optional[torch.Tensor],
     min_p_val: float,
     deterministic: bool,
+    generator: Optional[torch.Generator],
 ) -> torch.Tensor:
     with probs.device as device:
         probs = probs.float()
-        uniform_samples = uniform_samples.float()
         maybe_min_p_arr = (
             maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
         )
         samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
         torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
             probs,
-            uniform_samples,
             samples,
+            indices,
             maybe_min_p_arr,
             min_p_val,
             deterministic,
-            get_cuda_stream(),
+            generator,
         )
         return samples


 def min_p_sampling_from_probs(
     probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
     min_p: Union[torch.Tensor, float],
+    indices: Optional[torch.Tensor] = None,
     deterministic: bool = True,
+    generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
 ) -> torch.Tensor:
-    if uniform_samples.dim() == 2:
-        # Take the first row (round) of uniform_samples
-        uniform_samples = uniform_samples[0]
+    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
+    Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
+    this operator implements GPU-based rejection sampling without explicit sorting.
+    Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    probs: torch.Tensor
+        Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
+        and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
+        shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
+        probability distributions.
+    min_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    indices: Optional[torch.Tensor]
+        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
+        For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
+        This allows reusing the same probability distribution for multiple outputs.
+        If indices is not provided, the i-th output will be sampled from the i-th row of probs.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
+    generator: Optional[torch.Generator]
+        A random number generator for the operation.
+    check_nan: bool
+        Whether to check nan in :attr:`probs`, default is ``False``.
+
+    Returns
+    -------
+    samples: torch.Tensor
+        Sampled categories, shape ``(batch_size,)``.
+
+    Note
+    ----
+    This function expects float32 inputs, and the output is int32.
+    """
     if check_nan:
         if torch.any(torch.isnan(probs)):
             raise ValueError("Input probs contains NaN.")
     return _min_p_sampling_from_probs_internal(
-        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
+        probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
    )
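In short, callers no longer pre-draw uniform_samples or check a success flag: randomness comes from an optional torch.Generator and the op returns the sampled token ids directly. A hedged usage sketch (my own example; whether the keyword names match your installed build should be checked against the signatures above):

    import torch
    import sgl_kernel

    probs = torch.softmax(torch.randn(16, 32000, device="cuda"), dim=-1)

    gen = torch.Generator(device="cuda")
    gen.manual_seed(0)

    # Old API: top_k_top_p_sampling_from_probs(probs, uniform_samples, ...) -> (samples, success)
    # New API: draws randomness internally and returns samples only.
    samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        probs,
        torch.full((16,), 20, device="cuda"),   # per-request top-k
        torch.full((16,), 0.9, device="cuda"),  # per-request top-p
        filter_apply_order="joint",
        generator=gen,
    )

    # The optional `indices` argument (see the docstring above) lets several outputs
    # reuse one probability row instead of duplicating rows of `probs`.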
sgl-kernel/tests/test_sampling.py  (view file @ c08a717c)

@@ -5,8 +5,8 @@ import sgl_kernel
 import torch


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5])
 def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
     torch.manual_seed(42)

@@ -16,14 +16,13 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
         k = int(vocab_size * 0.1)
     else:
         raise ValueError("p not recognized")
-    max_top_k_trails = 32
     eps = 1e-4
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     # top-p mask
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
     cdf = torch.cumsum(sorted_prob, dim=-1)
-    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
     mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
     # top-k mask
     sorted_prob, _ = torch.sort(normalized_prob, descending=True)

@@ -31,40 +30,35 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
     mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
     # overall mask
     mask = torch.minimum(mask_top_p, mask_top_k)
-    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
-        0
-    )
-    top_p_tensor = torch.full((batch_size,), p).to(0)
-    top_k_tensor = torch.full((batch_size,), k).to(0)
+    top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
+    top_k_tensor = torch.full((batch_size,), k, device="cuda:0")

     num_trails = 1000
     for _ in range(num_trails):
-        uniform_samples.uniform_()
-        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
+        samples = sgl_kernel.top_k_top_p_sampling_from_probs(
             normalized_prob,
-            uniform_samples,
             top_k_tensor,
             top_p_tensor,
             filter_apply_order="joint",
         )

-        assert torch.all(success)
         assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
         assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
             torch.arange(batch_size), samples
         ]


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
 def test_top_p_renorm_probs(batch_size, vocab_size, p):
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    torch.manual_seed(42)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
     cdf = torch.cumsum(sorted_prob, dim=-1)
-    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
     mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
-    renorm_prob_ground_truth = normalized_prob
+    renorm_prob_ground_truth = normalized_prob.clone()
     renorm_prob_ground_truth[mask == 0] = 0
     renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
         dim=-1, keepdim=True

@@ -79,56 +73,54 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p):
     )


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("k", [10, 100, 500])
 def test_top_k_renorm_probs(batch_size, vocab_size, k):
     if k > vocab_size:
         pytest.skip("k should be less than vocab_size")
     torch.manual_seed(42)
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     sorted_prob, _ = torch.sort(normalized_prob, descending=True)
     pivot = sorted_prob[:, k - 1]
     mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
-    renorm_prob_ground_truth = normalized_prob
+    renorm_prob_ground_truth = normalized_prob.clone()
     renorm_prob_ground_truth[mask == 0] = 0
     renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
         dim=-1, keepdim=True
     )

     renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
-    torch.testing.assert_close(
-        renorm_prob_ground_truth,
-        renorm_prob,
-        rtol=1e-3,
-        atol=1e-3,
-    )
+    for i in range(batch_size):
+        torch.testing.assert_close(
+            renorm_prob_ground_truth[i],
+            renorm_prob[i],
+            rtol=1e-3,
+            atol=1e-3,
+        )


-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
 def test_min_p_sampling(batch_size, vocab_size, p):
     torch.manual_seed(42)
-    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
     normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
     sorted_prob, indices = torch.sort(normalized_prob, descending=False)
     # scale min-p
     top_probs = sorted_prob[:, -1].unsqueeze(-1)
     scaled_p = p * top_probs
     # min-p mask
-    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
+    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
     mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
-    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
-    min_p_tensor = torch.full((batch_size,), p).to(0)
+    min_p_tensor = torch.full((batch_size,), p, device="cuda:0")

     num_trails = 1000
     for _ in range(num_trails):
-        uniform_samples.uniform_()
         samples = sgl_kernel.min_p_sampling_from_probs(
             normalized_prob,
-            uniform_samples,
             min_p_tensor,
         )

@@ -136,6 +128,10 @@ def test_min_p_sampling(batch_size, vocab_size, p):
             torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
         ]
+        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
+            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
+        ]


 if __name__ == "__main__":
     pytest.main([__file__])