Unverified Commit 53dcc750 authored by Yuan Luo, committed by GitHub

[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 432f2053
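
This commit exposes two new Python entry points, top_k_mask_logits and top_k_top_p_sampling_from_logits, through the sgl_kernel package (see the sampling.py and __init__.py hunks below). A minimal usage sketch, assuming a CUDA device and an sgl-kernel build that includes this change; the shapes and thresholds are illustrative, not taken from the commit.

import torch
import sgl_kernel

# Illustrative logits: any (batch_size, vocab_size) float32 tensor on the GPU works.
logits = torch.randn(4, 32000, device="cuda", dtype=torch.float32)

# Sample directly from logits: top-k mask -> softmax -> top-p rejection sampling.
samples = sgl_kernel.top_k_top_p_sampling_from_logits(
    logits, top_k=20, top_p=0.9, filter_apply_order="top_k_first"
)
print(samples)  # sampled token indices, shape (4,)
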
import itertools
import sgl_kernel
import torch
import triton
import triton.testing


def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Reference PyTorch implementation of joint top-k top-p sampling."""
    batch_size, vocab_size = normalized_prob.shape
    samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device)
    for i in range(batch_size):
        p_val = top_p[i].item()
        k_val = top_k[i].item()
        # top-p mask
        sorted_prob, indices = torch.sort(normalized_prob[i], descending=False)
        cdf = torch.cumsum(sorted_prob, dim=-1)
        mask_top_p = torch.zeros(
            vocab_size, dtype=torch.int32, device=normalized_prob.device
        )
        mask_top_p.scatter_add_(0, indices, (cdf > (1 - p_val) - eps).int())
        # top-k mask
        sorted_prob_desc, _ = torch.sort(normalized_prob[i], descending=True)
        pivot = sorted_prob_desc[k_val - 1]
        mask_top_k = (normalized_prob[i] >= pivot).int()
        # joint mask
        mask = torch.minimum(mask_top_p, mask_top_k).bool()
        # sample from masked probs
        masked_probs = normalized_prob[i] * mask
        masked_probs = masked_probs / masked_probs.sum()
        idx = torch.multinomial(masked_probs, 1)
        samples[i] = idx
    return samples


def calculate_diff(batch_size, vocab_size, p):
    """Run the Torch reference and the SGLang kernel and sanity-check the outputs."""
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)
    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )
    # Sampling is stochastic, so exact equality is not expected; check that both
    # implementations return valid token indices for every request.
    assert torch.all((torch_samples >= 0) & (torch_samples < vocab_size))
    assert torch.all((sglang_samples >= 0) & (sglang_samples < vocab_size))
    print(f"batch_size={batch_size}, vocab_size={vocab_size}, p={p}, k={k}: OK")

# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)
    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    # Correctness check
    for cfg in configs:
        calculate_diff(*cfg)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_sampling.run(print_data=True)
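
Because both samplers draw random tokens, calculate_diff above can only sanity-check the outputs rather than compare them element-wise. The following is a hedged sketch, not part of the commit, of a stronger statistical cross-check that compares empirical token frequencies over repeated draws; it reuses the reference implementation defined above, and num_trials and the printed metric are illustrative choices.

def frequency_cross_check(batch_size=16, vocab_size=128, p=0.5, k=16, num_trials=2000):
    # Illustrative sketch: compare empirical token frequencies of the Torch reference
    # and the SGLang kernel over many draws, since individual samples are random.
    device = torch.device("cuda")
    torch.manual_seed(0)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)
    counts_ref = torch.zeros(batch_size, vocab_size, device=device)
    counts_kernel = torch.zeros(batch_size, vocab_size, device=device)
    for _ in range(num_trials):
        s_ref = torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob, top_k_tensor, top_p_tensor
        )
        s_kernel = sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
        )
        ones = torch.ones(batch_size, 1, device=device)
        counts_ref.scatter_add_(1, s_ref.view(-1, 1), ones)
        counts_kernel.scatter_add_(1, s_kernel.long().view(-1, 1), ones)
    # Both frequency tables should approach the same masked, renormalized distribution.
    max_gap = (counts_ref - counts_kernel).abs().max().item() / num_trials
    print(f"max frequency gap across batch and vocab: {max_gap:.4f}")
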
...@@ -345,15 +345,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);

  m.def(
      "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
      "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

  m.def(
      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
      "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);

  m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
  m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);

  m.def(
      "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
      "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
...
...@@ -593,6 +593,10 @@ void top_p_sampling_from_probs(
    double top_p_val,
    bool deterministic,
    std::optional<at::Generator> gen);

void top_k_mask_logits(
    at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);

torch::Tensor moe_wna16_marlin_gemm(
    torch::Tensor& a,
    std::optional<torch::Tensor> const& c_or_none,
...
...@@ -85,7 +85,9 @@ from sgl_kernel.moe import (
)
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_mask_logits,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_logits,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
...
from typing import Optional, Tuple, Union

import torch

from sgl_kernel.utils import _to_tensor_scalar_tuple
...@@ -383,3 +383,161 @@ def min_p_sampling_from_probs(
    return _min_p_sampling_from_probs_internal(
        probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
    )


def _top_k_mask_logits_internal(
    logits: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    logits = logits.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    mask_logits = torch.empty_like(logits)
    torch.ops.sgl_kernel.top_k_mask_logits.default(
        logits, mask_logits, maybe_top_k_arr, top_k_val
    )
    return mask_logits


def top_k_mask_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py

    Fused GPU kernel for masking logits by top-k thresholding.

    Parameters
    ----------
    logits: torch.Tensor
        Logits before softmax, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k
        threshold for masking logits; should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k logits and set the rest to negative infinity.

    Returns
    -------
    masked_logits: torch.Tensor
        Masked logits, shape ``(batch_size, num_classes)``.

    Examples
    --------
    >>> import torch
    >>> import sgl_kernel
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> top_k = 3
    >>> logits = torch.randn(batch_size, vocab_size).to(0)
    >>> logits
    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
    >>> masked_logits = sgl_kernel.top_k_mask_logits(logits, top_k)
    >>> masked_logits
    tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
            [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
            [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
            [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')

    Note
    ----
    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.

    See Also
    --------
    top_k_renorm_probs
    """
    return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))


def top_k_top_p_sampling_from_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py

    Fused GPU kernel for top-k and top-p sampling from pre-softmax logits.
    This operator implements GPU-based rejection sampling without explicit sorting;
    see the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
    The multiple rounds of rejection sampling run inside a single CUDA kernel,
    which is more efficient than a naive implementation that launches a series of kernels.

    Parameters
    ----------
    logits: torch.Tensor
        Pre-softmax logits for sampling. When indices is not provided, the shape should be
        ``(batch_size, num_classes)`` and the i-th output will be sampled from the i-th row
        of logits. When indices is provided, the shape should be
        ``(unique_batch_size, num_classes)``, where unique_batch_size is the number of
        unique probability distributions.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
    top_p: Union[torch.Tensor, float]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
    indices: Optional[torch.Tensor]
        Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in logits.
        For example, if indices[i] = j, then the i-th output will be sampled from logits[j].
        This allows reusing the same distribution for multiple outputs.
        If indices is not provided, the i-th output will be sampled from the i-th row of logits.
    filter_apply_order: str
        The order of applying top-k and top-p filtering, either ``"top_k_first"`` or ``"joint"``.
        If ``"top_k_first"``, the top-k filter is applied first, then top-p sampling runs on the top-k results.
        If ``"joint"``, the top-k and top-p filters are applied simultaneously in each round.
        Default is ``"top_k_first"``.
    deterministic: bool
        Whether to use the deterministic kernel implementation. Default is ``True``.
    generator: Optional[torch.Generator]
        A random number generator for the operation.
    check_nan: bool
        Whether to check for NaN in the probabilities derived from :attr:`logits`. Default is ``False``.

    Returns
    -------
    samples: torch.Tensor
        Sampled categories, shape ``(batch_size,)``.

    Note
    ----
    This function expects float32 inputs, and the output is int32.
    """
    if filter_apply_order == "top_k_first":
        masked_logits = top_k_mask_logits(logits, top_k)
        probs = torch.softmax(masked_logits, dim=-1)
        return top_p_sampling_from_probs(
            probs,
            top_p,
            indices,
            deterministic,
            check_nan=check_nan,
            generator=generator,
        )
    elif filter_apply_order == "joint":
        probs = torch.softmax(logits, dim=-1)
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            indices,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
            generator,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
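
The docstring of top_k_mask_logits above notes that masking logits and then applying softmax should be equivalent to renormalizing the top-k probabilities. A minimal sketch of that check, assuming a CUDA device and assuming sgl_kernel.top_k_renorm_prob (exported in __init__.py above) takes (probs, top_k); the shapes and top_k value are illustrative.

import torch
import sgl_kernel

torch.manual_seed(0)
logits = torch.randn(4, 128, device="cuda")
top_k = 10

# Path 1: keep the top-k logits (rest -> -inf), then softmax.
probs_from_masked_logits = torch.softmax(
    sgl_kernel.top_k_mask_logits(logits, top_k), dim=-1
)

# Path 2: softmax first, then renormalize over the top-k probabilities
# (assumed signature: top_k_renorm_prob(probs, top_k)).
probs_renormed = sgl_kernel.top_k_renorm_prob(torch.softmax(logits, dim=-1), top_k)

print(torch.allclose(probs_from_masked_logits, probs_renormed, atol=1e-5))
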
...@@ -5,6 +5,54 @@ import sgl_kernel
import torch


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment(
    batch_size, vocab_size, k, p
):
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    generator_logits = torch.Generator("cuda:0")
    generator_probs = generator_logits.clone_state()
    samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="top_k_first", generator=generator_logits
    )
    samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="top_k_first",
        generator=generator_probs,
    )
    assert torch.all(samples == samples_ref)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_joint_alignment(
    batch_size, vocab_size, k, p
):
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    generator_logits = torch.Generator("cuda:0")
    generator_probs = generator_logits.clone_state()
    samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="joint", generator=generator_logits
    )
    samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="joint",
        generator=generator_probs,
    )
    assert torch.all(samples == samples_ref)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
...