Unverified Commit fccbfa37 authored by Yineng Zhang, committed by GitHub

format: add clang-format for sgl-kernel (#2483)

parent 2f9bd0fa
.clang-format (new file)
BasedOnStyle: Google
IndentWidth: 2
ColumnLimit: 120
AllowShortFunctionsOnASingleLine: Empty
DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
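For illustration, a hypothetical snippet (not taken from the repository) formatted under these options: IndentWidth 2 gives the two-space indent, PointerAlignment: Left binds * and & to the type, ColumnLimit 120 lets longer checks stay on one line, and AllowShortFunctionsOnASingleLine: Empty collapses only empty function bodies.

// Hypothetical example of the configured style; the identifiers here are made up.
float* device_buffer = nullptr;  // PointerAlignment: Left keeps '*' with the type
TORCH_CHECK(device_buffer != nullptr, "device_buffer must be allocated");  // fits on one line under ColumnLimit 120
void noop() {}           // an empty body may stay on a single line
int add(int a, int b) {  // non-empty bodies are always expanded
  return a + b;
}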
Makefile
-.PHONY: tree ln install build clean test
+.PHONY: tree ln install build clean test format

 tree:
 	@tree --prune -I "__pycache__|*.egg-info|*.so|build"
@@ -17,3 +17,6 @@ clean:

 test:
 	@pytest tests/
+
+format:
+	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
@@ -2,12 +2,10 @@
 torch::Tensor warp_reduce_cuda(torch::Tensor input);

-#define CHECK_CUDA(x) \
-  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) \
-  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) \
-  CHECK_CUDA(x);       \
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
   CHECK_CONTIGUOUS(x)

 torch::Tensor warp_reduce(torch::Tensor input) {
......
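The body of warp_reduce is elided above. A typical wrapper for this pattern (a sketch under that assumption, not the actual elided code) validates the tensor with CHECK_INPUT and forwards to the CUDA entry point declared at the top of the file:

torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_INPUT(input);              // expands to CHECK_CUDA(input); CHECK_CONTIGUOUS(input)
  return warp_reduce_cuda(input);  // dispatch to the CUDA implementation
}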
@@ -25,34 +25,28 @@ __device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) {
   int lane = threadIdx.x % 32;
   int wid = threadIdx.x / 32;

-  val = warpReduceSum(val);  // First reduce within warp
+  val = warpReduceSum(val);  // First reduce within warp

-  if (lane == 0)
-    shared[wid] = val;  // Write reduced value to shared memory
+  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory

-  __syncthreads();  // Wait for all partial reductions
+  __syncthreads();  // Wait for all partial reductions

   // Read from shared memory only if that warp existed
   val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0;

-  if (wid == 0)
-    val = warpReduceSum(val);  // Final reduce within first warp
+  if (wid == 0) val = warpReduceSum(val);  // Final reduce within first warp

   return val;
 }

 template <typename scalar_t>
 __global__ void warp_reduce_cuda_kernel(
-    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits>
-        input,
-    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output,
-    int N) {
+    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
+    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output, int N) {
   scalar_t sum = 0;

   // Grid-stride loop
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
     sum += input[i];
   }
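blockReduceSum and the kernel above rely on a warpReduceSum helper that falls outside this hunk. The standard implementation of that pattern (a sketch assuming CUDA 9+ synchronized shuffles; the file's actual helper may differ in detail) sums the 32 lane values of a warp with __shfl_down_sync:

template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  // Halve the shuffle distance each step (16, 8, 4, 2, 1); lane 0 ends up holding the warp-wide sum.
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}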
@@ -84,13 +78,11 @@ torch::Tensor warp_reduce_cuda(torch::Tensor input) {
   // Allocate output tensor for partial sums
   auto output = torch::empty({blocks}, input.options());

-  AT_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "warp_reduce_cuda", ([&] {
-        warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
-            input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
-            output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
-            N);
-      }));
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "warp_reduce_cuda", ([&] {
+    warp_reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
+        input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
+        output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(), N);
+  }));

   // Sum the partial results
   return output.sum();
......
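The lines that compute N, threads, and blocks for the kernel launch are also outside the hunk above. A conventional setup (a sketch under common assumptions, not the elided code itself) is:

const int N = static_cast<int>(input.numel());   // total number of elements to reduce
const int threads = 1024;                        // threads per block; a multiple of the 32-lane warp size
const int blocks = (N + threads - 1) / threads;  // ceil(N / threads) so every element is visited

Because the kernel uses a grid-stride loop, blocks can also be capped at a fixed value; each thread then accumulates several elements before the block-level reduction, and the final output.sum() combines the per-block partial sums.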