Unverified Commit 47eb139f authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: use warp reduce as a simple example (#2304)

parent 5c18a037
......@@ -185,3 +185,36 @@ work_dirs/
*.csv
!logo.png
# Prerequisites
*.d
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
[build-system]
requires = ["setuptools>=61.0", "wheel", "torch"]
build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
version = "0.0.2"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: C++",
"Programming Language :: CUDA",
]
dependencies = ["numpy"]
[project.optional-dependencies]
srt = ["torch"]
all = ["sgl-kernel[srt]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.packages.find]
exclude = [
"dist*",
"tests*",
]
[tool.setuptools]
package-dir = {"sgl_kernel" = "src/sgl-kernel"}
packages = ["sgl_kernel", "sgl_kernel.ops", "sgl_kernel.csrc"]
[tool.wheel]
exclude = [
"dist*",
"tests*",
]
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# The warp-reduce example extension: a C++ binding file plus the CUDA kernel.
ext_modules = [
    CUDAExtension(
        "sgl_kernel.ops.warp_reduce_cuda",
        [
            "src/sgl-kernel/csrc/warp_reduce.cc",
            "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
        ],
    )
]

setup(
    name="sgl-kernel",
    version="0.0.2",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    ext_modules=ext_modules,
    # BuildExtension drives nvcc/ninja for the CUDAExtension above.
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)
# Package root: re-export the CUDA-backed warp_reduce op as the public API.
from .ops import warp_reduce
__all__ = ["warp_reduce"]
#include <torch/extension.h>
#include <vector>
// Implemented in warp_reduce_kernel.cu; launches the reduction kernel.
torch::Tensor warp_reduce_cuda(torch::Tensor input);
// Argument validation helpers: raise a torch error naming the offending tensor.
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// NOTE(review): expands to two statements without a do { } while (0) wrapper;
// safe at statement level as used below, but would misbehave inside an
// unbraced `if`.
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// Validate `input` (must be a contiguous CUDA tensor) and dispatch to the
// CUDA implementation. Same checks as CHECK_INPUT, spelled out explicitly.
torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_CUDA(input);
  CHECK_CONTIGUOUS(input);
  return warp_reduce_cuda(input);
}
// Python bindings: the op is exposed as `warp_reduce_cuda.reduce`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
}
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
// Full-warp participation mask for the *_sync warp intrinsics (all 32 lanes).
#define FINAL_MASK 0xffffffff
// Threads per block used by the host launcher; a multiple of 32 so every
// warp in the block is full.
#define BLOCK_SIZE 256
// Element-wise addition helper.
// NOTE(review): not referenced by the kernels below — presumably kept as part
// of the example; confirm before removing.
template <typename scalar_t>
__device__ __forceinline__ scalar_t add(scalar_t a, scalar_t b) {
return a + b;
}
// Sum `val` across the 32 lanes of the calling warp via shuffle-down.
// After the loop, lane 0 holds the full warp total; other lanes hold
// partial sums. All 32 lanes must be active, since the mask is FINAL_MASK
// (requires the *_sync warp intrinsics, CUDA 9+).
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
#pragma unroll
  for (int delta = 32 / 2; delta > 0; delta >>= 1)
    val += __shfl_down_sync(FINAL_MASK, val, delta);
  return val;
}
// Sum `val` across all threads of the block.
// Two-level reduction: each warp reduces its own 32 lanes, lane 0 of each
// warp publishes its partial to shared memory, then warp 0 reduces the
// per-warp partials. Must be called by every thread in the block (contains
// __syncthreads()); supports blockDim.x <= 1024 (<= 32 warps).
// NOTE(review): warpReduceSum uses the full-warp mask, so blockDim.x should
// be a multiple of 32 (true for the BLOCK_SIZE=256 launch in this file).
template <typename scalar_t>
__device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
  __shared__ scalar_t shared[32];  // one partial sum per warp
  int lane = threadIdx.x % 32;
  int wid = threadIdx.x / 32;
  // Number of warps in the block, rounded UP. The previous blockDim.x / 32
  // truncated, which dropped the trailing partial warp's contribution and
  // made the whole reduction return 0 whenever blockDim.x < 32.
  int num_warps = (blockDim.x + 31) / 32;
  val = warpReduceSum(val); // First reduce within warp
  if (lane == 0)
    shared[wid] = val; // Write reduced value to shared memory
  __syncthreads(); // Wait for all partial reductions
  // Warp 0 picks up one partial per existing warp; lanes >= num_warps
  // contribute 0 (shared[] is uninitialized past num_warps).
  val = (threadIdx.x < num_warps) ? shared[lane] : 0;
  if (wid == 0)
    val = warpReduceSum(val); // Final reduce within first warp
  return val;
}
// One partial sum per block: each thread accumulates a grid-strided slice of
// `input`, the block reduces the per-thread partials, and thread 0 writes the
// block total to output[blockIdx.x].
template <typename scalar_t>
__global__ void warp_reduce_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
        input,
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
    int N) {
  const int stride = blockDim.x * gridDim.x;
  scalar_t acc = 0;
  // Grid-stride loop: correct regardless of how many blocks were launched.
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N; idx += stride)
    acc += input[idx];
  // Block-wide reduction; every thread must reach this call.
  acc = blockReduceSum(acc);
  // Thread 0 holds the block total after the reduction.
  if (threadIdx.x == 0)
    output[blockIdx.x] = acc;
}
// Sum all elements of a 1-D CUDA float/double tensor.
// Launches a grid-stride reduction kernel that produces one partial sum per
// block, then sums the partials on-device via Tensor::sum() (a 0-dim result).
// NOTE(review): the empty-input path returns a 1-element zero tensor, not a
// 0-dim tensor — inconsistent with the non-empty path, but preserved here for
// backward compatibility.
torch::Tensor warp_reduce_cuda(torch::Tensor input) {
  // Input validation
  TORCH_CHECK(input.dim() == 1, "1D tensor expected");
  TORCH_CHECK(input.is_cuda(), "CUDA tensor expected");
  const auto N = input.size(0);
  // The kernel takes N as int; guard against silent truncation.
  TORCH_CHECK(N <= 2147483647, "tensor too large for int32 indexing");
  // Handle empty tensor
  if (N == 0) {
    return torch::zeros({1}, input.options());
  }
  // Calculate grid dimensions: ceil-div, capped so the grid stays small for
  // huge inputs — the kernel's grid-stride loop covers any remainder.
  const int threads = BLOCK_SIZE;
  const int64_t wanted_blocks = (N + threads - 1) / threads;
  const int blocks = static_cast<int>(wanted_blocks < 65535 ? wanted_blocks : 65535);
  // Allocate output tensor for partial sums (one slot per block)
  auto output = torch::empty({blocks}, input.options());
  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "warp_reduce_cuda", ([&] {
        warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
            input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
            output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
            static_cast<int>(N));
      }));
  // Surface launch-configuration errors immediately instead of letting them
  // appear as a mysterious failure at the next synchronizing call.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "warp_reduce kernel launch failed: ", cudaGetErrorString(err));
  // Sum the partial results
  return output.sum();
}
from .warp_reduce_cuda import reduce as _reduce


def warp_reduce(input_tensor):
    """Sum-reduce a tensor using the compiled CUDA extension.

    The underlying op validates that the input is a contiguous 1-D CUDA
    tensor and returns the sum of its elements.
    """
    return _reduce(input_tensor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment