Commit 5f051c97 authored by Max Rietmann

Optimized CUDA kernels for improved backward gradient computation



Introduce new CUDA kernels, `s2_attention_bwd_dkvq_kernel_mbT` and
`s2_attention_kernel_mbT`, for more efficient computation of the backward
gradients and the forward attention, respectively. Both kernels optimize memory
access patterns and achieve coalesced loads by transposing the inputs so that
the channel dimension is last.
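The channels-last layout is produced on the host by a permute-and-copy, as in
the launcher code below, e.g. for the keys:

    torch::Tensor kxP = kx.permute({0,2,3,1}).contiguous(); // channel dim moved last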

Forward kernel written by Mauro Bisson.
Backward kernel written by Andrea Paris (aparis@ethz.ch) and Max Rietmann.

The parallelization strategy computes one output per warp, with the threads of
the warp computing the dot product in parallel. Because the inputs are
transposed to put the channel dimension last, the dot-product memory access
pattern is perfectly coalesced, leading to excellent performance. This holds
for both the forward and backward kernels; a minimal sketch of the pattern
follows.
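Illustration only (not part of this commit): a minimal standalone kernel showing
the warp-per-output pattern; the name `dot_per_warp` and the flat row-major
layout are hypothetical. Each of the 32 lanes strides across channels (adjacent
lanes read adjacent addresses), then a butterfly shuffle combines the partial
sums.

    // Sketch: one warp per output row; channels-last so lane reads coalesce.
    __global__ void dot_per_warp(int num_channels,
                                 const float* __restrict__ a,  // [nrows][num_channels]
                                 const float* __restrict__ b,  // [nrows][num_channels]
                                 float* __restrict__ out) {    // [nrows]
        const int lane = threadIdx.x;  // 0..31, one 32-thread warp per block
        const int row  = blockIdx.x;   // one output per warp
        float partial = 0.0f;
        for (int c = lane; c < num_channels; c += 32) {
            partial += a[row*num_channels + c] * b[row*num_channels + c];
        }
        // butterfly reduction: after log2(32) steps every lane holds the sum
        for (int i = 16; i; i /= 2) {
            partial += __shfl_xor_sync(0xFFFFFFFF, partial, i);
        }
        if (lane == 0) out[row] = partial;
    }

Launched with one 32-thread block per output row, e.g.
`dot_per_warp<<<nrows, 32>>>(C, a, b, out)`, every load is fully coalesced
because adjacent lanes touch adjacent channels.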
Co-authored-by: Mauro Bisson <maurob@nvidia.com>
Co-authored-by: Max Rietmann <mrietmann@nvidia.com>
Co-authored-by: Andrea Paris <aparis@ethz.ch>
parent 318fc76e
// coding=utf-8
//
-// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -34,12 +34,37 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>

using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS   (64)

#define DIV_UP(a,b) (((a)+((b)-1))/(b))

#define NNZ_TRESH (32)

#define CHECK_CUDA(call) {                                              \
    cudaError_t err = call;                                             \
    if( cudaSuccess != err) {                                           \
        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n",   \
                __FILE__, __LINE__, cudaGetErrorString( err) );         \
        exit(EXIT_FAILURE);                                             \
    }}

#define CHECK_ERROR(errorMessage) {                                             \
    cudaError_t err = cudaGetLastError();                                       \
    if( cudaSuccess != err) {                                                   \
        fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n",       \
                errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );   \
        exit(EXIT_FAILURE);                                                     \
    }}
__device__ static float atomicMax(float* address, float val)
{
    int* address_as_i = (int*) address;

@@ -70,9 +95,9 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,

    float* sh_qy_ho_wo = (float *)&sharedMem[2];

    if (threadIdx.x == 0) {
        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
        sh_alpha_sum[0] = 0.0;
    }
    __syncthreads();

    int ho = blockIdx.x;

@@ -171,10 +196,105 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
}
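// warp-wide sum: butterfly (XOR) shuffles halve the number of distinct
// partial sums at each step, so after log2(WARP_SIZE) steps every lane
// holds the total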
static __device__ float __warp_sum(float val) {
    #pragma unroll
    for(int i = WARP_SIZE/2; i; i /= 2) {
        val += __shfl_xor_sync(FULL_MASK, val, i);
    }
    return val;
}
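// For each output point (ho,wo), the kernel below computes
//   y[c] = sum_j w_j * exp(q.k_j) * v_j[c]  /  sum_j w_j * exp(q.k_j),
// where j runs over the sparse neighborhood of row ho (CSR arrays
// psi_row_offset/psi_col_idx) and w_j are the quadrature weights; the
// exponentials are evaluated against a numerically stable running maximum.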
// one warp per (ho,wo)
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_kernel_mbT(int num_channels,
                             int nlon_in,
                             int nlat_out,
                             int nlon_out,
                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                             const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                             torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
                             const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                             const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                             const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {

    extern __shared__ float sh[];
    float *shy = sh + threadIdx.y*num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;

    // guard: one (ho,wo) per warp; the grid covers nlat_out*nlon_out outputs
    if (wid >= uint64_t(nlat_out)*nlon_out) {
        return;
    }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho*nlon_out);
    for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
        // useless read, y is always zeroed before kernel is called
        shy[chan] = y[batchId][chan][ho][wo];
#else
        shy[chan] = 0;
#endif
    }
    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    // CSR row bounds of the sparse neighborhood for output latitude ho
    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho+1];
    const int rlen = rend-rbeg;
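    // Streaming (online) softmax: keep the running maximum qdotk_max and
    // accumulate alpha_sum and shy[] relative to it; whenever a larger
    // maximum appears, rescale the accumulated values by
    // exp(old_max - new_max) before adding the new term.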
    for(int off = 0; off < rlen; off++) {

        const int64_t col = psi_col_idx[rbeg+off];
        const int hi = col / nlon_in;
        const int wi = col - (hi*nlon_in);
        const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
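        // wip = (wi + wo) mod nlon_in: shift the stored neighbor longitude by
        // the output longitude wo (longitudinal symmetry of the sparsity pattern)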
        float qdotk = 0.0f;
        for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][ho][ wo][chan]*
                     kx[batchId][hi][wip][chan];
        }
        qdotk = __warp_sum(qdotk);

        float qdotk_max_tmp;
        float alpha;
        float exp_save;

        qdotk_max_tmp = max(qdotk_max, qdotk);
        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha + alpha_sum*exp_save;

        for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan]*exp_save + vx[batchId][hi][wip][chan]*alpha;
        }
        qdotk_max = qdotk_max_tmp;
    }
    for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
    }

    return;
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    at::Tensor vx,
                                    at::Tensor qy,
                                    at::Tensor quad_weights,
                                    at::Tensor psi_col_idx,
                                    at::Tensor psi_row_off,

@@ -193,36 +313,44 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    auto stream = at::cuda::getCurrentCUDAStream().stream();

-    // allocate output
-    torch::Tensor y = torch::zeros_like(qy);

    size_t uo_num_channels = kx.size(1);
-    size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
    const int batch_size = kx.size(0);

-    // cuda grid y,z size limitations
-    assert(nlon_out < 65535);
-    assert(batch_size < 65535);
-    // block-parallel over output points and batches
-    dim3 gridDim(nlat_out,nlon_out,batch_size);
-    // threads compute "blocks" of neighborhood and also "blocks" of channels
-    // note: blocksize of 512 runs into resource limits
-    dim3 blockDim(256,1,1);
-    s2_attention_kernel<<<gridDim, blockDim, sharedMemSize,stream>>>(
-        uo_num_channels, nlon_in, nlat_out, nlon_out,
-        kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
-    );

    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    torch::Tensor kxP = kx.permute({0,2,3,1}).contiguous();
    torch::Tensor vxP = vx.permute({0,2,3,1}).contiguous();
    torch::Tensor qyP = qy.permute({0,2,3,1}).contiguous();
    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float)*uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel_mbT<THREADS>
        <<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
            kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    // printf("s2_attention_kernel_mbT execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();
...