"vscode:/vscode.git/clone" did not exist on "64b0ae3041b05dc6c9c42484dbad01c3b6bc5cd1"
Unverified Commit c08a717c authored by PGFLMG's avatar PGFLMG Committed by GitHub
Browse files

[Feat] Update sgl-kernel flashinfer to latest main version (#5500)


Co-authored-by: default avatarzhyncs <me@zhyncs.com>
parent f13d65a7
......@@ -58,8 +58,8 @@ FetchContent_Populate(repo-deepgemm)
# flashinfer
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
GIT_TAG sgl-kernel
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
......
......@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/elementwise
*/
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
......@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
m.def(
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, int cuda_stream) -> ()");
"min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
m.def(
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
"cuda_stream) -> ()");
m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
m.def(
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
"cuda_stream) -> ()");
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
m.def(
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
"cuda_stream) -> ()");
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
m.def(
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
/*
......
......@@ -21,7 +21,8 @@ limitations under the License.
using namespace flashinfer;
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
void sgl_fused_add_rmsnorm(
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
......@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
static_cast<c_type*>(weight.data_ptr()),
batch_size,
hidden_size,
input.stride(0),
residual.stride(0),
eps,
enable_pdl,
torch_current_stream);
TORCH_CHECK(
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
......
......@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
DType threshold_acc) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
DType prob_acc = 0.0;
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
......@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
if (aggregate_relu_q_minus_p > u) {
break;
......@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
......
......@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
/*
* From csrc/elementwise
*/
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void sgl_fused_add_rmsnorm(
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
......@@ -254,48 +254,38 @@ void segment_packbits(
*/
void min_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val,
bool deterministic,
int64_t cuda_stream);
std::optional<at::Generator> gen);
void top_k_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
int64_t top_k_val,
int64_t cuda_stream);
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
void top_p_renorm_probs(
at::Tensor probs,
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
int64_t cuda_stream);
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
void top_k_top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor success,
at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_k_arr,
double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
int64_t cuda_stream);
std::optional<at::Generator> gen);
void top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples,
at::Tensor samples,
at::Tensor success,
at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic,
int64_t cuda_stream);
std::optional<at::Generator> gen);
namespace flash {
/*
......
......@@ -11,17 +11,69 @@ def rmsnorm(
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: bool = False,
) -> torch.Tensor:
r"""Root mean square normalization.
``out[i] = (input[i] / RMS(input)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: bool
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
Returns
-------
output: torch.Tensor
Normalized tensor, shape (batch_size, hidden_size).
"""
if out is None:
out = torch.empty_like(input)
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: bool = False,
) -> None:
torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
r"""Fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: bool
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
"""
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
def gemma_rmsnorm(
......@@ -29,20 +81,68 @@ def gemma_rmsnorm(
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: bool = False,
) -> torch.Tensor:
r"""Gemma-style root mean square normalization.
``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: bool
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
Returns
-------
output: torch.Tensor
Gemma Normalized tensor, shape (batch_size, hidden_size).
"""
if out is None:
out = torch.empty_like(input)
torch.ops.sgl_kernel.gemma_rmsnorm.default(
out, input, weight, eps, get_cuda_stream()
)
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: bool = False,
) -> None:
r"""Gemma-style fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: bool
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
"""
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, get_cuda_stream()
input, residual, weight, eps, enable_pdl
)
......
......@@ -13,11 +13,7 @@ def _top_k_renorm_probs_internal(
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_k_renorm_probs.default(
probs,
renorm_probs,
maybe_top_k_arr,
top_k_val,
get_cuda_stream(),
probs, renorm_probs, maybe_top_k_arr, top_k_val
)
return renorm_probs
......@@ -26,6 +22,30 @@ def top_k_renorm_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
for re-normalizing probabilities, should be in ``(0, num_classes)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Note
----
This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_k_sampling_from_probs``.
"""
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
......@@ -41,11 +61,7 @@ def _top_p_renorm_probs_internal(
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_p_renorm_probs.default(
probs,
renorm_probs,
maybe_top_p_arr,
top_p_val,
get_cuda_stream(),
probs, renorm_probs, maybe_top_p_arr, top_p_val
)
return renorm_probs
......@@ -54,6 +70,32 @@ def top_p_renorm_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for for
re-normalizing probabilities, should be in ``(0, 1)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We mask out the probabilities less than `threshold` where the cumulative sum
of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Note
----
This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_p_sampling_from_probs``.
"""
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
......@@ -62,93 +104,187 @@ top_p_renorm_prob = top_p_renorm_probs
def _top_p_sampling_from_probs_internal(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> Tuple[torch.Tensor, torch.Tensor]:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
probs,
uniform_samples,
samples,
success,
indices,
maybe_top_p_arr,
top_p_val,
deterministic,
get_cuda_stream(),
generator,
)
return samples, success
return samples
def top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_p_sampling_from_probs_internal(
probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
)
def _top_k_top_p_sampling_from_probs_internal(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> Tuple[torch.Tensor, torch.Tensor]:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
probs,
uniform_samples,
samples,
success,
indices,
maybe_top_k_arr,
top_k_val,
maybe_top_p_arr,
top_p_val,
deterministic,
get_cuda_stream(),
generator,
)
return samples, success
return samples
def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if filter_apply_order == "top_k_first":
renorm_probs = top_k_renorm_probs(probs, top_k)
return top_p_sampling_from_probs(
renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
renorm_probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
if check_nan:
......@@ -156,10 +292,11 @@ def top_k_top_p_sampling_from_probs(
raise ValueError("Input probs contains NaN.")
return _top_k_top_p_sampling_from_probs_internal(
probs,
uniform_samples,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
......@@ -167,44 +304,82 @@ def top_k_top_p_sampling_from_probs(
def _min_p_sampling_from_probs_internal(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_min_p_arr: Optional[torch.Tensor],
min_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
uniform_samples = uniform_samples.float()
maybe_min_p_arr = (
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
probs,
uniform_samples,
samples,
indices,
maybe_min_p_arr,
min_p_val,
deterministic,
get_cuda_stream(),
generator,
)
return samples
def min_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
min_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
if uniform_samples.dim() == 2:
# Take the first row (round) of uniform_samples
uniform_samples = uniform_samples[0]
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
min_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _min_p_sampling_from_probs_internal(
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
)
......@@ -5,8 +5,8 @@ import sgl_kernel
import torch
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
torch.manual_seed(42)
......@@ -16,14 +16,13 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
k = int(vocab_size * 0.1)
else:
raise ValueError("p not recognized")
max_top_k_trails = 32
eps = 1e-4
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
# top-p mask
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
cdf = torch.cumsum(sorted_prob, dim=-1)
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
# top-k mask
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
......@@ -31,40 +30,35 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
# overall mask
mask = torch.minimum(mask_top_p, mask_top_k)
uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
0
)
top_p_tensor = torch.full((batch_size,), p).to(0)
top_k_tensor = torch.full((batch_size,), k).to(0)
top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
top_k_tensor = torch.full((batch_size,), k, device="cuda:0")
num_trails = 1000
for _ in range(num_trails):
uniform_samples.uniform_()
samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
samples = sgl_kernel.top_k_top_p_sampling_from_probs(
normalized_prob,
uniform_samples,
top_k_tensor,
top_p_tensor,
filter_apply_order="joint",
)
assert torch.all(success)
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
torch.arange(batch_size), samples
]
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
torch.manual_seed(42)
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
cdf = torch.cumsum(sorted_prob, dim=-1)
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
renorm_prob_ground_truth = normalized_prob
renorm_prob_ground_truth = normalized_prob.clone()
renorm_prob_ground_truth[mask == 0] = 0
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
dim=-1, keepdim=True
......@@ -79,56 +73,54 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p):
)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
if k > vocab_size:
pytest.skip("k should be less than vocab_size")
torch.manual_seed(42)
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
pivot = sorted_prob[:, k - 1]
mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
renorm_prob_ground_truth = normalized_prob
renorm_prob_ground_truth = normalized_prob.clone()
renorm_prob_ground_truth[mask == 0] = 0
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
dim=-1, keepdim=True
)
renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
torch.testing.assert_close(
renorm_prob_ground_truth,
renorm_prob,
rtol=1e-3,
atol=1e-3,
)
for i in range(batch_size):
torch.testing.assert_close(
renorm_prob_ground_truth[i],
renorm_prob[i],
rtol=1e-3,
atol=1e-3,
)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
torch.manual_seed(42)
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
# scale min-p
top_probs = sorted_prob[:, -1].unsqueeze(-1)
scaled_p = p * top_probs
# min-p mask
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
min_p_tensor = torch.full((batch_size,), p).to(0)
min_p_tensor = torch.full((batch_size,), p, device="cuda:0")
num_trails = 1000
for _ in range(num_trails):
uniform_samples.uniform_()
samples = sgl_kernel.min_p_sampling_from_probs(
normalized_prob,
uniform_samples,
min_p_tensor,
)
......@@ -136,6 +128,10 @@ def test_min_p_sampling(batch_size, vocab_size, p):
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
]
assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
]
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment