"vscode:/vscode.git/clone" did not exist on "ff182ad6694ada3c01b3514eeae03392b2761b92"
Unverified commit 5d9d15e7, authored by Xiaoyu Zhang, committed by GitHub

support fp32 in sampling_scaling_penalties kernel (#3121)

parent 665e5e85
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <pytorch_extension_utils.h>
-#include <THC/THCAtomics.cuh>
 #include <flashinfer/vec_dtypes.cuh>
@@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] {
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] {
     uint32_t vec_size = 16 / sizeof(scalar_t);
     const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
     sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
......
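For context on the launch math above: each thread consumes 16 bytes per step, so vec_size is 4 for float32 and 8 for half/bfloat16, and the fp32 path only needs the wider dispatch. A minimal sketch of that vectorization pattern, assuming the flashinfer vec_t load/store interface and the usual repetition-penalty semantics (the committed kernel body is elided above, so this is illustrative, not the actual kernel):

#include <flashinfer/vec_dtypes.cuh>

// Hedged sketch, not the committed kernel: shows only the 16-byte
// vectorization implied by vec_size = 16 / sizeof(scalar_t).
template <typename scalar_t>
__global__ void penalties_sketch(const scalar_t* logits, const scalar_t* penalties,
                                 scalar_t* output, int32_t numel) {
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  const int32_t base = (blockIdx.x * blockDim.x + threadIdx.x) * vec_size;
  if (base + (int32_t)vec_size <= numel) {
    flashinfer::vec_t<scalar_t, vec_size> l, p;
    l.load(logits + base);     // one vectorized 16-byte load per tensor
    p.load(penalties + base);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      float v = static_cast<float>(l[i]);
      float q = static_cast<float>(p[i]);
      // assumed semantics: divide positive logits, multiply negative ones
      l[i] = static_cast<scalar_t>(v > 0.f ? v / q : v * q);
    }
    l.store(output + base);
  } else {
    // scalar tail for the last partial vector
    for (int32_t i = base; i < numel; ++i) {
      float v = static_cast<float>(logits[i]);
      float q = static_cast<float>(penalties[i]);
      output[i] = static_cast<scalar_t>(v > 0.f ? v / q : v * q);
    }
  }
}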
 #pragma once
 #include <pytorch_extension_utils.h>
 #include <torch/extension.h>
+#include <sstream>

@@ -44,3 +45,20 @@ inline int getSMVersion() {
   CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
   return sm_major * 10 + sm_minor;
 }
+#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...)           \
+  [&]() -> bool {                                                                        \
+    switch (pytorch_dtype) {                                                             \
+      case at::ScalarType::Float: {                                                      \
+        using c_type = float;                                                            \
+        return __VA_ARGS__();                                                            \
+      }                                                                                  \
+      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \
+      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \
+      default:                                                                           \
+        std::ostringstream oss;                                                          \
+        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
+        TORCH_CHECK(false, oss.str());                                                   \
+        return false;                                                                    \
+    }                                                                                    \
+  }()
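The macro expands to an immediately-invoked lambda returning bool, so a caller binds c_type to float, half, or bfloat16 at runtime and the lambda's return value propagates out of the switch. A hedged usage sketch (the helper below is hypothetical, not from this commit):

#include <torch/extension.h>

// Hypothetical helper, for illustration only: returns the element size of a
// float32/float16/bfloat16 tensor by dispatching through the new macro.
inline int64_t element_size_via_dispatch(const torch::Tensor& t) {
  int64_t size = 0;
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(t.scalar_type(), c_type, [&] {
    size = sizeof(c_type);  // c_type is the concrete element type here
    return true;            // becomes the bool returned by the macro
  });
  return size;
}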
@@ -2,10 +2,14 @@ import pytest
 import torch

 from sgl_kernel import sampling_scaling_penalties

+batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
+vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
+dtypes = [torch.float32, torch.half, torch.bfloat16]
+
-@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
-@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("batch_size", batch_sizes)
+@pytest.mark.parametrize("vocab_size", vocab_sizes)
+@pytest.mark.parametrize("dtype", dtypes)
 def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
     device = torch.device("cuda")
     rtol = 1e-3
......
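The test body is elided above; it presumably compares the kernel output against a pure-PyTorch reference. A hedged sketch of such a reference, assuming the where(logits > 0, logits / p, logits * p) penalty semantics (all names and inputs below are illustrative, not from the commit):

import torch

# Hypothetical reference implementation, assuming the usual penalty semantics:
# positive logits are divided by the penalty, negative ones multiplied.
def sampling_scaling_penalties_ref(logits, scaling_penalties):
    return torch.where(
        logits > 0, logits / scaling_penalties, logits * scaling_penalties
    )

# Illustrative comparison inside the test (inputs and tolerances assumed):
# logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
# penalties = torch.rand_like(logits) + 0.5
# torch.testing.assert_close(
#     sampling_scaling_penalties(logits, penalties),
#     sampling_scaling_penalties_ref(logits, penalties),
#     rtol=rtol, atol=rtol,
# )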