Unverified Commit fccbfa37 authored by Yineng Zhang, committed by GitHub

format: add clang-format for sgl-kernel (#2483)

parent 2f9bd0fa
.clang-format:
BasedOnStyle: Google
IndentWidth: 2
ColumnLimit: 120
AllowShortFunctionsOnASingleLine: Empty
DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
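
These settings track Google C++ style with a 120-column limit. A hypothetical snippet (not code from this commit) illustrating the options that differ from stock Google style:

// Hypothetical C++ snippet, not from this commit, showing the configured style.
namespace sgl {       // NamespaceIndentation: None -> namespace members stay unindented
void noop() {}        // AllowShortFunctionsOnASingleLine: Empty -> only empty bodies merge
int* twice(int* x) {  // PointerAlignment: Left (DerivePointerAlignment off) -> '*' hugs the type
  *x *= 2;            // IndentWidth: 2
  return x;
}
}  // namespace sgl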
Makefile:
-.PHONY: tree ln install build clean test
+.PHONY: tree ln install build clean test format
 tree:
 	@tree --prune -I "__pycache__|*.egg-info|*.so|build"
@@ -17,3 +17,6 @@ clean:
 test:
 	@pytest tests/
+format:
+	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
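
With this target, running `make format` from the sgl-kernel directory applies clang-format -i to every .cc/.cu/.cuh/.h file under src and tests, then isort and black to the Python files, so the whole tree is reformatted with one command.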
@@ -2,12 +2,10 @@
 torch::Tensor warp_reduce_cuda(torch::Tensor input);
 
-#define CHECK_CUDA(x) \
-  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) \
-  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) \
   CHECK_CUDA(x);       \
   CHECK_CONTIGUOUS(x)
 
 torch::Tensor warp_reduce(torch::Tensor input) {
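
These are the standard PyTorch extension input guards, normally invoked at the top of the C++ entry point. Only the signature of warp_reduce appears in this hunk; a minimal sketch of how the body presumably uses the macros (the body is an assumption):

torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_INPUT(input);              // expands to CHECK_CUDA(input); CHECK_CONTIGUOUS(input)
  return warp_reduce_cuda(input);  // forward to the CUDA implementation declared above
}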
@@ -25,34 +25,28 @@ __device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
   int lane = threadIdx.x % 32;
   int wid = threadIdx.x / 32;
 
   val = warpReduceSum(val); // First reduce within warp
 
-  if (lane == 0)
-    shared[wid] = val; // Write reduced value to shared memory
+  if (lane == 0) shared[wid] = val; // Write reduced value to shared memory
 
   __syncthreads(); // Wait for all partial reductions
 
   // Read from shared memory only if that warp existed
   val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;
 
-  if (wid == 0)
-    val = warpReduceSum(val); // Final reduce within first warp
+  if (wid == 0) val = warpReduceSum(val); // Final reduce within first warp
 
   return val;
 }
 
 template <typename scalar_t>
 __global__ void warp_reduce_cuda_kernel(
-    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
-        input,
-    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
-    int N) {
+    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
+    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
   scalar_t sum = 0;
 
   // Grid-stride loop
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
     sum += input[i];
   }
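
blockReduceSum first reduces within each 32-thread warp via warpReduceSum, which this hunk references but does not show. A standard shuffle-based implementation, given here as a sketch of what the surrounding file likely contains:

template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  // Each step adds the value held by the lane `offset` positions higher,
  // halving the active width (16, 8, 4, 2, 1).
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;  // lane 0 ends up holding the warp-wide sum
}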
@@ -84,13 +78,11 @@ torch::Tensor warp_reduce_cuda(torch::Tensor input) {
   // Allocate output tensor for partial sums
   auto output = torch::empty({blocks}, input.options());
 
-  AT_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "warp_reduce_cuda", ([&] {
-        warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
-            input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
-            output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
-            N);
-      }));
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
+                               warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
+                                   input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
+                                   output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
+                             }));
 
   // Sum the partial results
   return output.sum();
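
The values of N, threads, and blocks are computed in lines collapsed out of this hunk; a typical setup for a grid-stride reduction kernel of this shape would be (an assumption, not the file's exact code):

// Plausible launch configuration for the collapsed lines above; actual values may differ.
const int N = input.numel();                                     // total element count
const int threads = 1024;                                        // threads per block
const int blocks = std::min((N + threads - 1) / threads, 1024);  // cap the grid size

Each block presumably reduces its partial sum with blockReduceSum and writes one value per block into output (shape {blocks}); output.sum() above then folds the per-block partials into the final scalar on the device.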