Unverified Commit 5de4051b authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: integrate sampling kernels into sgl-kernel (#3086)


Co-authored-by: default avatarZihao Ye <expye@outlook.com>
parent e0cd65c2
...@@ -128,6 +128,7 @@ ext_modules = [ ...@@ -128,6 +128,7 @@ ext_modules = [
"3rdparty/flashinfer/csrc/group_gemm_sm90.cu", "3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
"3rdparty/flashinfer/csrc/norm.cu", "3rdparty/flashinfer/csrc/norm.cu",
"3rdparty/flashinfer/csrc/sampling.cu", "3rdparty/flashinfer/csrc/sampling.cu",
"3rdparty/flashinfer/csrc/renorm.cu",
], ],
include_dirs=include_dirs, include_dirs=include_dirs,
extra_compile_args={ extra_compile_args={
......
...@@ -11,12 +11,16 @@ from sgl_kernel.ops import ( ...@@ -11,12 +11,16 @@ from sgl_kernel.ops import (
init_custom_reduce, init_custom_reduce,
int8_scaled_mm, int8_scaled_mm,
lightning_attention_decode, lightning_attention_decode,
min_p_sampling_from_probs,
moe_align_block_size, moe_align_block_size,
register_graph_buffers, register_graph_buffers,
rmsnorm, rmsnorm,
rotary_embedding, rotary_embedding,
sampling_scaling_penalties, sampling_scaling_penalties,
silu_and_mul, silu_and_mul,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
) )
__all__ = [ __all__ = [
...@@ -31,11 +35,15 @@ __all__ = [ ...@@ -31,11 +35,15 @@ __all__ = [
"get_graph_buffer_ipc_meta", "get_graph_buffer_ipc_meta",
"init_custom_reduce", "init_custom_reduce",
"int8_scaled_mm", "int8_scaled_mm",
"lightning_attention_decode",
"min_p_sampling_from_probs",
"moe_align_block_size", "moe_align_block_size",
"register_graph_buffers", "register_graph_buffers",
"rmsnorm", "rmsnorm",
"rotary_embedding", "rotary_embedding",
"sampling_scaling_penalties", "sampling_scaling_penalties",
"lightning_attention_decode",
"silu_and_mul", "silu_and_mul",
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
"top_p_renorm_prob",
] ]
...@@ -61,6 +61,30 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); ...@@ -61,6 +61,30 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
// min p sampling from probs
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
int64_t cuda_stream);
// top k renorm probs
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
unsigned int top_k_val, int64_t cuda_stream);
// top p renorm probs
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val, int64_t cuda_stream);
// top k top p sampling from probs
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
// top p sampling from probs
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce // trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
...@@ -94,4 +118,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -94,4 +118,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)"); m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
// bmm fp8 // bmm fp8
m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)"); m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
// min p sampling from probs
m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)");
// top k renorm probs
m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)");
// top p renorm probs
m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)");
// top k top p sampling from probs
m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)");
// top p sampling from probs
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)");
} }
from typing import Optional from typing import Optional, Tuple, Union
import torch import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce from sgl_kernel.ops._kernels import all_reduce as _all_reduce
...@@ -17,6 +17,9 @@ from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm ...@@ -17,6 +17,9 @@ from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import ( from sgl_kernel.ops._kernels import (
lightning_attention_decode as _lightning_attention_decode, lightning_attention_decode as _lightning_attention_decode,
) )
from sgl_kernel.ops._kernels import (
min_p_sampling_from_probs as _min_p_sampling_from_probs,
)
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
...@@ -25,7 +28,19 @@ from sgl_kernel.ops._kernels import ( ...@@ -25,7 +28,19 @@ from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties, sampling_scaling_penalties as _sampling_scaling_penalties,
) )
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs
from sgl_kernel.ops._kernels import (
top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs,
)
from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs
from sgl_kernel.ops._kernels import (
top_p_sampling_from_probs as _top_p_sampling_from_probs,
)
from sgl_kernel.ops.utils import (
_get_cache_buf,
_get_cuda_stream,
_to_tensor_scalar_tuple,
)
def init_custom_reduce( def init_custom_reduce(
...@@ -236,3 +251,213 @@ def bmm_fp8( ...@@ -236,3 +251,213 @@ def bmm_fp8(
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale) _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
return out return out
def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Invoke the CUDA top-k renormalization kernel.

    Either ``maybe_top_k_arr`` holds a per-row k (scalar ignored by the
    kernel) or it is ``None`` and ``top_k_val`` applies to every row.
    Returns a new tensor; ``probs`` itself is not modified.
    """
    with probs.device as device:
        probs = probs.float()
        # Kernel expects int32 per-row k values when a tensor is given.
        top_k_arr = None if maybe_top_k_arr is None else maybe_top_k_arr.int()
        out = torch.empty_like(probs)
        _top_k_renorm_probs(
            probs,
            out,
            top_k_arr,
            top_k_val,
            _get_cuda_stream(device),
        )
        return out
def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    """Renormalize each row of ``probs`` over its top-k entries.

    ``top_k`` may be a scalar applied to all rows or a per-row tensor.
    """
    tensor_arg, scalar_arg = _to_tensor_scalar_tuple(top_k)
    return _top_k_renorm_probs_internal(probs, tensor_arg, scalar_arg)


# Singular-form alias kept for the public API surface.
top_k_renorm_prob = top_k_renorm_probs
def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    """Invoke the CUDA top-p renormalization kernel.

    ``maybe_top_p_arr`` carries a per-row p when not ``None``; otherwise
    the scalar ``top_p_val`` is used for every row. Returns a fresh
    tensor and leaves ``probs`` untouched.
    """
    with probs.device as device:
        probs = probs.float()
        # Per-row thresholds must be float32 for the kernel.
        top_p_arr = None if maybe_top_p_arr is None else maybe_top_p_arr.float()
        out = torch.empty_like(probs)
        _top_p_renorm_probs(
            probs,
            out,
            top_p_arr,
            top_p_val,
            _get_cuda_stream(device),
        )
        return out
def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Renormalize each row of ``probs`` over its top-p nucleus.

    ``top_p`` may be a scalar applied to all rows or a per-row tensor.
    """
    tensor_arg, scalar_arg = _to_tensor_scalar_tuple(top_p)
    return _top_p_renorm_probs_internal(probs, tensor_arg, scalar_arg)


# Singular-form alias kept for the public API surface.
top_p_renorm_prob = top_p_renorm_probs
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Invoke the CUDA top-p sampling kernel.

    Returns ``(samples, success)``: one int32 token id per row of
    ``probs`` and a bool flag indicating whether sampling succeeded
    within the kernel's rejection rounds.
    """
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        top_p_arr = None if maybe_top_p_arr is None else maybe_top_p_arr.float()
        batch = probs.size(0)
        samples = torch.empty(batch, dtype=torch.int32, device=device)
        success = torch.empty(batch, dtype=torch.bool, device=device)
        _top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids from ``probs`` restricted to the top-p nucleus.

    ``top_p`` may be a scalar or a per-row tensor. When ``check_nan``
    is set, reject inputs containing NaN up front.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p)
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, top_p_arr, top_p_val, deterministic
    )
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Invoke the fused CUDA top-k + top-p sampling kernel.

    For each filter, either the per-row tensor (``maybe_*_arr``) or the
    scalar value applies. Returns ``(samples, success)`` with one int32
    token id and one bool success flag per row.
    """
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        # Kernel dtypes: int32 k values, float32 p thresholds.
        top_k_arr = None if maybe_top_k_arr is None else maybe_top_k_arr.int()
        top_p_arr = None if maybe_top_p_arr is None else maybe_top_p_arr.float()
        batch = probs.size(0)
        samples = torch.empty(batch, dtype=torch.int32, device=device)
        success = torch.empty(batch, dtype=torch.bool, device=device)
        _top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            top_k_arr,
            top_k_val,
            top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success
def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids with combined top-k and top-p filtering.

    ``filter_apply_order`` selects the strategy: ``"top_k_first"``
    renormalizes over the top-k and then top-p samples from the result,
    while ``"joint"`` applies both filters in one fused kernel call.
    Raises ``ValueError`` for any other order, or (with ``check_nan``)
    when ``probs`` contains NaN on the joint path.
    """
    if filter_apply_order == "top_k_first":
        filtered = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            filtered, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    if filter_apply_order == "joint":
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
        top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k)
        top_p_arr, top_p_val = _to_tensor_scalar_tuple(top_p)
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            top_k_arr,
            top_k_val,
            top_p_arr,
            top_p_val,
            deterministic,
        )
    raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    """Invoke the CUDA min-p sampling kernel.

    Returns one int32 token id per row of ``probs``. Either the per-row
    tensor ``maybe_min_p_arr`` or the scalar ``min_p_val`` supplies the
    threshold.
    """
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        min_p_arr = None if maybe_min_p_arr is None else maybe_min_p_arr.float()
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        _min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            min_p_arr,
            min_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples
def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    """Sample token ids from ``probs`` using min-p filtering.

    ``min_p`` may be a scalar or a per-row tensor. A 2-D
    ``uniform_samples`` is reduced to its first row, since min-p
    sampling needs only a single round of uniforms per batch entry.
    """
    if uniform_samples.dim() == 2:
        # Only one sampling round is consumed; drop the extra rows.
        uniform_samples = uniform_samples[0]
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    min_p_arr, min_p_val = _to_tensor_scalar_tuple(min_p)
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, min_p_arr, min_p_val, deterministic
    )
...@@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: ...@@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
buf = torch.empty(bytes, dtype=torch.uint8, device=device) buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf _cache_buf[key] = buf
return buf return buf
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py
import pytest
import sgl_kernel
import torch
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    # Every sample drawn by the fused top-k/top-p kernel must fall inside
    # the intersection of a reference top-p mask and top-k mask.
    torch.manual_seed(42)
    # Pair a loose p with a tight k and vice versa so both filters matter.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    max_top_k_trails = 32
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # top-p mask: keep tokens whose ascending CDF exceeds 1-p (eps for
    # float round-off), scattered back to vocabulary order.
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
    # top-k mask: keep tokens at least as probable as the k-th largest.
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # overall mask: a token is admissible only if both filters keep it.
    mask = torch.minimum(mask_top_p, mask_top_k)
    # One row of uniforms per rejection-sampling round the kernel may use.
    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
        0
    )
    top_p_tensor = torch.full((batch_size,), p).to(0)
    top_k_tensor = torch.full((batch_size,), k).to(0)
    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(success)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        # On failure, show the probabilities of the offending samples.
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    # Compare the kernel's top-p renormalization against a PyTorch
    # reference built from the ascending CDF.
    torch.manual_seed(42)  # match the sibling tests; reproducible input
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # Keep tokens whose ascending CDF reaches 1-p, in vocabulary order.
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
    # Clone before masking: the original assignment aliased
    # normalized_prob, so the in-place zeroing corrupted the kernel's
    # input before the kernel ran.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )
    renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    # Compare the kernel's top-k renormalization against a PyTorch
    # reference that keeps everything >= the k-th largest probability.
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # Clone before masking: the original assignment aliased
    # normalized_prob, so the in-place zeroing corrupted the kernel's
    # input before the kernel ran.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )
    renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    # Every sample drawn by the min-p kernel must have probability at
    # least p times the row's maximum probability.
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p: the threshold is relative to each row's top probability
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask: admissible tokens, scattered back to vocabulary order
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
    min_p_tensor = torch.full((batch_size,), p).to(0)
    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            min_p_tensor,
        )
        # On failure, show which sampled tokens fell outside the mask.
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]
# Allow running this test module directly (python test_sampling.py).
if __name__ == "__main__":
    pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment