Unverified commit a53454c5, authored by Yineng Zhang and committed by GitHub
Browse files

fix: sgl-kernel link cuda (#2906)

parent 6cb3974e
...@@ -11,6 +11,8 @@ docker run --rm \ ...@@ -11,6 +11,8 @@ docker run --rm \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
export CUDA_VERSION=${CUDA_VERSION} && \ export CUDA_VERSION=${CUDA_VERSION} && \
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
cd /sgl-kernel && \ cd /sgl-kernel && \
${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
" "
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "sgl-kernel" name = "sgl-kernel"
version = "0.0.2.post13" version = "0.0.2.post14"
description = "Kernel Library for SGLang" description = "Kernel Library for SGLang"
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
......
...@@ -41,7 +41,7 @@ nvcc_flags = [ ...@@ -41,7 +41,7 @@ nvcc_flags = [
] ]
cxx_flags = ["-O3"] cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"] libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
ext_modules = [ ext_modules = [
CUDAExtension( CUDAExtension(
name="sgl_kernel.ops._kernels", name="sgl_kernel.ops._kernels",
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh> #include <THC/THCAtomics.cuh>
#include "utils.hpp" #include "utils.hpp"
#include "vectorization.cuh" #include "vectorization.cuh"
// Applies per-element scaling penalties to logits:
//
//   output[i] = logits[i] > 0 ? logits[i] / scaling_penalties[i]
//                             : logits[i] * scaling_penalties[i]
//
// Parameters:
//   logits, scaling_penalties  input buffers, read linearly over [0, numel)
//   output                     destination buffer, written linearly over [0, numel)
//   numel                      number of elements in each buffer
//
// Launch layout: 1-D grid of 1-D blocks. Every loop advances by the total
// number of launched threads (blockDim.x * gridDim.x), so any grid size that
// launches at least one thread covers the whole range.
//
// The bulk of the range is processed four elements at a time through
// vec4_t<scalar_t>; the remaining (numel % 4) elements are processed one by
// one. The vectorized path is only taken when all three pointers satisfy
// alignof(vec4_t<scalar_t>); otherwise the scalar loop handles every element,
// which avoids misaligned vector accesses on offset views.
template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int64_t numel) {
  // 64-bit indices: the host passes a 64-bit element count, and
  // blockIdx.x * blockDim.x can exceed the 32-bit range on large launches.
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // Alignment is a power of two, so (alignment - 1) is a mask of the low bits
  // that must be zero in every pointer for the reinterpret_casts to be valid.
  constexpr uintptr_t kVecAlignMask = alignof(vec4_t<scalar_t>) - 1;
  const uintptr_t combined_addr = reinterpret_cast<uintptr_t>(logits) |
                                  reinterpret_cast<uintptr_t>(scaling_penalties) |
                                  reinterpret_cast<uintptr_t>(output);
  const bool can_vectorize = (combined_addr & kVecAlignMask) == 0;

  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);

  // Number of complete 4-element groups; zero disables the vector loop.
  const int64_t num_vec_elems = can_vectorize ? (numel >> 2) : 0;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += stride) {
    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
    vec4_t<scalar_t> out_vec;

    out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
    out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
    out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
    out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;

    vectorized_output[i] = out_vec;
  }

  // Scalar tail: everything the vector loop did not cover.
  const int64_t start_idx = num_vec_elems * 4;
  for (int64_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > 0 ? logit / penalty : logit * penalty;
  }
}
// Host entry point: returns a new tensor with scaling penalties applied to
// `logits` element-wise (positive logits are divided by the penalty, all
// others are multiplied by it).
//
// Parameters:
//   logits             CUDA tensor of a floating dtype (float, double, half,
//                      bfloat16 are dispatched).
//   scaling_penalties  CUDA tensor on the same device, with the same dtype and
//                      the same number of elements as `logits`.
//
// Returns a contiguous tensor with the shape and dtype of `logits`.
//
// Raises a c10::Error (via TORCH_CHECK) when the inputs are not CUDA tensors,
// live on different devices, or differ in dtype or element count, and when the
// kernel launch itself reports an error.
//
// The kernel is enqueued on the current CUDA stream of the device that holds
// `logits`; this function does not synchronize.
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  TORCH_CHECK(logits.is_cuda(), "sampling_scaling_penalties: logits must be a CUDA tensor");
  TORCH_CHECK(scaling_penalties.is_cuda(), "sampling_scaling_penalties: scaling_penalties must be a CUDA tensor");
  TORCH_CHECK(logits.device() == scaling_penalties.device(),
              "sampling_scaling_penalties: logits and scaling_penalties must be on the same device");
  TORCH_CHECK(logits.scalar_type() == scaling_penalties.scalar_type(),
              "sampling_scaling_penalties: logits and scaling_penalties must have the same dtype");
  // The kernel reads both inputs over the same linear range, so a shorter
  // penalties tensor would be read out of bounds.
  TORCH_CHECK(logits.numel() == scaling_penalties.numel(),
              "sampling_scaling_penalties: logits and scaling_penalties must have the same number of elements");

  // Make the tensor's device current so the stream lookup and the launch
  // target the right GPU in multi-device processes.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(logits));

  // The kernel indexes raw memory linearly; materialize dense row-major
  // buffers so strided or transposed inputs are handled correctly.
  const torch::Tensor logits_contig = logits.contiguous();
  const torch::Tensor penalties_contig = scaling_penalties.contiguous();
  auto output = torch::empty_like(logits_contig);

  const int64_t numel = logits_contig.numel();
  if (numel == 0) {
    // A zero-block grid is an invalid launch configuration; nothing to do.
    return output;
  }

  constexpr int kThreads = 512;
  // Each thread handles four elements per iteration of the vectorized loop.
  constexpr int64_t kElemsPerBlock = static_cast<int64_t>(kThreads) * 4;
  const int blocks = static_cast<int>((numel + kElemsPerBlock - 1) / kElemsPerBlock);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, logits_contig.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
        sampling_scaling_penalties_kernel<scalar_t><<<blocks, kThreads, 0, stream>>>(
            logits_contig.data_ptr<scalar_t>(), penalties_contig.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
            numel);
        // Surface bad launch configurations immediately instead of at some
        // later, unrelated CUDA call.
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }));

  return output;
}
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
// Include both AMD and NVIDIA fp8 types to avoid circular import // Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring // TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h> #include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
// Vectorization containers // Vectorization containers
template <typename scalar_t> template <typename scalar_t>
...@@ -20,8 +20,7 @@ struct __align__(8) vec4_t { ...@@ -20,8 +20,7 @@ struct __align__(8) vec4_t {
template <typename quant_type_t> template <typename quant_type_t>
struct __align__(4) q8x4_t { struct __align__(4) q8x4_t {
static_assert(std::is_same_v<quant_type_t, int8_t> || static_assert(std::is_same_v<quant_type_t, int8_t> || std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>); std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
quant_type_t x; quant_type_t x;
quant_type_t y; quant_type_t y;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment