sglang · Commit 5de4051b

feat: integrate sampling kernels into sgl-kernel (#3086)

Authored Jan 24, 2025 by Yineng Zhang, committed via GitHub Jan 24, 2025 (unverified).
Co-authored-by: Zihao Ye <expye@outlook.com>
Parent commit: e0cd65c2

Showing 6 changed files with 419 additions and 3 deletions (+419 −3).
sgl-kernel/setup.py                                 +1    −0
sgl-kernel/src/sgl-kernel/__init__.py               +9    −1
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu    +34   −0
sgl-kernel/src/sgl-kernel/ops/__init__.py           +227  −2
sgl-kernel/src/sgl-kernel/ops/utils.py              +7    −0
sgl-kernel/tests/test_sampling.py                   +141  −0
sgl-kernel/setup.py

@@ -128,6 +128,7 @@ ext_modules = [
             "3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
             "3rdparty/flashinfer/csrc/norm.cu",
             "3rdparty/flashinfer/csrc/sampling.cu",
+            "3rdparty/flashinfer/csrc/renorm.cu",
         ],
         include_dirs=include_dirs,
         extra_compile_args={
sgl-kernel/src/sgl-kernel/__init__.py

@@ -11,12 +11,16 @@ from sgl_kernel.ops import (
     init_custom_reduce,
     int8_scaled_mm,
     lightning_attention_decode,
+    min_p_sampling_from_probs,
     moe_align_block_size,
     register_graph_buffers,
     rmsnorm,
     rotary_embedding,
     sampling_scaling_penalties,
     silu_and_mul,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
 )

 __all__ = [
@@ -31,11 +35,15 @@ __all__ = [
     "get_graph_buffer_ipc_meta",
     "init_custom_reduce",
     "int8_scaled_mm",
+    "lightning_attention_decode",
+    "min_p_sampling_from_probs",
     "moe_align_block_size",
     "register_graph_buffers",
     "rmsnorm",
     "rotary_embedding",
     "sampling_scaling_penalties",
-    "lightning_attention_decode",
     "silu_and_mul",
+    "top_k_renorm_prob",
+    "top_k_top_p_sampling_from_probs",
+    "top_p_renorm_prob",
 ]
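With these re-exports, the new kernels become part of the package's public surface. A minimal renormalization sketch follows (not part of this commit; it assumes a CUDA device and follows the wrapper signatures and tests added further below). The sampling entry points are illustrated after the ops/__init__.py diff.

import torch
import sgl_kernel

# Row-wise probability distributions on the GPU.
probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)

# Zero out everything outside the top-k / top-p set and renormalize each row.
topk_probs = sgl_kernel.top_k_renorm_prob(probs, 50)   # top_k may be an int or a per-row tensor
topp_probs = sgl_kernel.top_p_renorm_prob(probs, 0.9)  # top_p may be a float or a per-row tensor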
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu

@@ -61,6 +61,30 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
              at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
 
+// min p sampling from probs
+void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
+                               bool deterministic, int64_t cuda_stream);
+
+// top k renorm probs
+void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_k_arr, unsigned int top_k_val,
+                        int64_t cuda_stream);
+
+// top p renorm probs
+void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
+                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                        int64_t cuda_stream);
+
+// top k top p sampling from probs
+void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples,
+                                     at::Tensor samples, at::Tensor success,
+                                     std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
+                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
+                                     bool deterministic, int64_t cuda_stream);
+
+// top p sampling from probs
+void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
+                               at::Tensor success, std::optional<at::Tensor> maybe_top_p_arr,
+                               double top_p_val, bool deterministic, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");

@@ -94,4 +118,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
   // bmm fp8
   m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
+  // min p sampling from probs
+  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)");
+  // top k renorm probs
+  m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)");
+  // top p renorm probs
+  m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)");
+  // top k top p sampling from probs
+  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)");
+  // top p sampling from probs
+  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)");
 }
sgl-kernel/src/sgl-kernel/ops/__init__.py

-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce

@@ -17,6 +17,9 @@ from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import (
     lightning_attention_decode as _lightning_attention_decode,
 )
+from sgl_kernel.ops._kernels import (
+    min_p_sampling_from_probs as _min_p_sampling_from_probs,
+)
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
 from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
 from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm

@@ -25,7 +28,19 @@ from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
 from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
-from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream
+from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs
+from sgl_kernel.ops._kernels import (
+    top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs,
+)
+from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs
+from sgl_kernel.ops._kernels import (
+    top_p_sampling_from_probs as _top_p_sampling_from_probs,
+)
+from sgl_kernel.ops.utils import (
+    _get_cache_buf,
+    _get_cuda_stream,
+    _to_tensor_scalar_tuple,
+)
 
 
 def init_custom_reduce(

@@ -236,3 +251,213 @@ def bmm_fp8(
     workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
     _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
     return out
+
+
+def _top_k_renorm_probs_internal(
+    probs: torch.Tensor,
+    maybe_top_k_arr: Optional[torch.Tensor],
+    top_k_val: int,
+) -> torch.Tensor:
+    with probs.device as device:
+        probs = probs.float()
+        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
+        renorm_probs = torch.empty_like(probs)
+        _top_k_renorm_probs(
+            probs,
+            renorm_probs,
+            maybe_top_k_arr,
+            top_k_val,
+            _get_cuda_stream(device),
+        )
+        return renorm_probs
+
+
+def top_k_renorm_probs(
+    probs: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+) -> torch.Tensor:
+    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
+
+
+top_k_renorm_prob = top_k_renorm_probs
+
+
+def _top_p_renorm_probs_internal(
+    probs: torch.Tensor,
+    maybe_top_p_arr: Optional[torch.Tensor],
+    top_p_val: float,
+) -> torch.Tensor:
+    with probs.device as device:
+        probs = probs.float()
+        maybe_top_p_arr = (
+            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
+        )
+        renorm_probs = torch.empty_like(probs)
+        _top_p_renorm_probs(
+            probs,
+            renorm_probs,
+            maybe_top_p_arr,
+            top_p_val,
+            _get_cuda_stream(device),
+        )
+        return renorm_probs
+
+
+def top_p_renorm_probs(
+    probs: torch.Tensor,
+    top_p: Union[torch.Tensor, float],
+) -> torch.Tensor:
+    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
+
+
+top_p_renorm_prob = top_p_renorm_probs
+
+
+def _top_p_sampling_from_probs_internal(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    maybe_top_p_arr: Optional[torch.Tensor],
+    top_p_val: float,
+    deterministic: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    with probs.device as device:
+        probs = probs.float()
+        uniform_samples = uniform_samples.float()
+        maybe_top_p_arr = (
+            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
+        )
+        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
+        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
+        _top_p_sampling_from_probs(
+            probs,
+            uniform_samples,
+            samples,
+            success,
+            maybe_top_p_arr,
+            top_p_val,
+            deterministic,
+            _get_cuda_stream(device),
+        )
+        return samples, success
+
+
+def top_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_p: Union[torch.Tensor, float],
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan:
+        if torch.any(torch.isnan(probs)):
+            raise ValueError("Input probs contains NaN.")
+    return _top_p_sampling_from_probs_internal(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
+    )
+
+
+def _top_k_top_p_sampling_from_probs_internal(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    maybe_top_k_arr: Optional[torch.Tensor],
+    top_k_val: int,
+    maybe_top_p_arr: Optional[torch.Tensor],
+    top_p_val: float,
+    deterministic: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    with probs.device as device:
+        probs = probs.float()
+        uniform_samples = uniform_samples.float()
+        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
+        maybe_top_p_arr = (
+            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
+        )
+        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
+        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
+        _top_k_top_p_sampling_from_probs(
+            probs,
+            uniform_samples,
+            samples,
+            success,
+            maybe_top_k_arr,
+            top_k_val,
+            maybe_top_p_arr,
+            top_p_val,
+            deterministic,
+            _get_cuda_stream(device),
+        )
+        return samples, success
+
+
+def top_k_top_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if filter_apply_order == "top_k_first":
+        renorm_probs = top_k_renorm_probs(probs, top_k)
+        return top_p_sampling_from_probs(
+            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
+        )
+    elif filter_apply_order == "joint":
+        if check_nan:
+            if torch.any(torch.isnan(probs)):
+                raise ValueError("Input probs contains NaN.")
+        return _top_k_top_p_sampling_from_probs_internal(
+            probs,
+            uniform_samples,
+            *_to_tensor_scalar_tuple(top_k),
+            *_to_tensor_scalar_tuple(top_p),
+            deterministic,
+        )
+    else:
+        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
+
+
+def _min_p_sampling_from_probs_internal(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    maybe_min_p_arr: Optional[torch.Tensor],
+    min_p_val: float,
+    deterministic: bool,
+) -> torch.Tensor:
+    with probs.device as device:
+        probs = probs.float()
+        uniform_samples = uniform_samples.float()
+        maybe_min_p_arr = (
+            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
+        )
+        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
+        _min_p_sampling_from_probs(
+            probs,
+            uniform_samples,
+            samples,
+            maybe_min_p_arr,
+            min_p_val,
+            deterministic,
+            _get_cuda_stream(device),
+        )
+        return samples
+
+
+def min_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    min_p: Union[torch.Tensor, float],
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> torch.Tensor:
+    if uniform_samples.dim() == 2:
+        # Take the first row (round) of uniform_samples
+        uniform_samples = uniform_samples[0]
+    if check_nan:
+        if torch.any(torch.isnan(probs)):
+            raise ValueError("Input probs contains NaN.")
+    return _min_p_sampling_from_probs_internal(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
+    )
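For context, a short sampling sketch (not part of the commit, assuming a CUDA device): filter_apply_order="top_k_first" renormalizes after top-k in Python and then runs the top-p kernel, while "joint" forwards both thresholds to the fused kernel, which reports per-row success of its rejection-sampling rounds. uniform_samples carries one uniform draw per round per row; min-p sampling only needs a single draw per row.

import torch
import sgl_kernel

batch_size, vocab_size = 8, 32000
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)

# One uniform draw per rejection-sampling round per row (32 rounds here).
uniform_samples = torch.rand(32, batch_size, device="cuda")

samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k=20, top_p=0.9, filter_apply_order="joint"
)
# success marks the rows whose rejection sampling converged within the provided rounds.

# Min-p sampling takes one uniform draw per row (a 2-D input uses only its first row).
min_p_samples = sgl_kernel.min_p_sampling_from_probs(
    probs, torch.rand(batch_size, device="cuda"), 0.05
)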
sgl-kernel/src/sgl-kernel/ops/utils.py

@@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
         buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
     return buf
+
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
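_to_tensor_scalar_tuple is what lets every wrapper accept either a single scalar threshold or a per-row tensor: each C++ entry point takes both an optional tensor argument and a scalar argument, and exactly one of the pair carries the value. A small illustration (not part of the commit):

import torch
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple

# Scalar threshold: no per-row tensor, the scalar carries the value.
print(_to_tensor_scalar_tuple(0.9))        # (None, 0.9)

# Per-row tensor threshold: the tensor carries the values, the scalar is a placeholder 0.
top_p = torch.full((8,), 0.9)
print(_to_tensor_scalar_tuple(top_p))      # (tensor([0.9, ...]), 0)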
sgl-kernel/tests/test_sampling.py (new file, 0 → 100644)

# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py
import pytest
import sgl_kernel
import torch


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    max_top_k_trails = 32
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # top-p mask
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
    # top-k mask
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # overall mask
    mask = torch.minimum(mask_top_p, mask_top_k)
    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(0)
    top_p_tensor = torch.full((batch_size,), p).to(0)
    top_k_tensor = torch.full((batch_size,), k).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )

        assert torch.all(success)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
    renorm_prob_ground_truth = normalized_prob
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
    renorm_prob_ground_truth = normalized_prob
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
    min_p_tensor = torch.full((batch_size,), p).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            min_p_tensor,
        )

        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]


if __name__ == "__main__":
    pytest.main([__file__])
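Usage note: these tests need a CUDA device (all tensors are moved to device 0 via .to(0)) and can be run directly with pytest, for example "pytest sgl-kernel/tests/test_sampling.py", matching the pytest.main([__file__]) entry point at the end of the file.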