Commit 5f051c97 authored by Max Rietmann

Optimized CUDA kernels for improved backward gradient computation



Introduce two new CUDA kernels, `s2_attention_bwd_dkvq_kernel_mbT` and
`s2_attention_kernel_mbT`, for more efficient computation of the backward
gradients and the forward attention, respectively. Both kernels optimize memory
access patterns and achieve coalesced global-memory reads by transposing the
tensors to a channel-last layout.

Forward kernel written by Mauro Bisson
Backward kernel written by Andrea Paris (aparis@ethz.ch) and Max Rietmann

The parallelization strategy computes one output per warp, with the lanes of the
warp computing the dot product in parallel. Because the inputs are transposed so
that the channel dimension is last, the dot-product memory access pattern is
perfectly coalesced, leading to excellent performance. This holds for both the
forward and the backward kernel (see the sketch below).

Co-authored-by: Mauro Bisson <maurob@nvidia.com>
Co-authored-by: Max Rietmann <mrietmann@nvidia.com>
Co-authored-by: Andrea Paris <aparis@ethz.ch>
parent 318fc76e
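A minimal, self-contained sketch of the strategy described above (illustrative only, not the kernels added by this commit): one warp computes one output dot product, the 32 lanes stride over a channel-last array so that consecutive lanes read consecutive floats (coalesced loads), and a shuffle butterfly combines the per-lane partial sums. All names in the sketch (warp_dot_demo, warp_sum_demo, npoints, nchan) are illustrative; the real kernels below additionally handle the sparse neighborhood, the softmax normalization and the quadrature weights.

#include <cstdio>
#include <cuda_runtime.h>

__device__ float warp_sum_demo(float v) {
    // butterfly reduction over the 32 lanes of a warp
    for (int i = 16; i > 0; i /= 2) v += __shfl_xor_sync(0xFFFFFFFF, v, i);
    return v;
}

// a and b are [npoints][nchan], stored channel-last; out[p] = <a_p, b_p>
__global__ void warp_dot_demo(const float* a, const float* b, float* out,
                              int npoints, int nchan) {
    const int warp = (blockIdx.x * blockDim.x + threadIdx.x) / 32; // one output per warp
    const int lane = threadIdx.x % 32;
    if (warp >= npoints) return;

    float acc = 0.0f;
    for (int c = lane; c < nchan; c += 32) {
        // lanes 0..31 read channels c, c+1, ...: contiguous, fully coalesced loads
        acc += a[(size_t)warp * nchan + c] * b[(size_t)warp * nchan + c];
    }
    acc = warp_sum_demo(acc);
    if (lane == 0) out[warp] = acc;
}

int main() {
    const int npoints = 1024, nchan = 256, threads = 64;
    const int warps_per_block = threads / 32;

    float *a, *b, *out;
    cudaMallocManaged(&a, (size_t)npoints * nchan * sizeof(float));
    cudaMallocManaged(&b, (size_t)npoints * nchan * sizeof(float));
    cudaMallocManaged(&out, npoints * sizeof(float));
    for (size_t i = 0; i < (size_t)npoints * nchan; i++) { a[i] = 1.0f; b[i] = 0.5f; }

    const int blocks = (npoints + warps_per_block - 1) / warps_per_block; // same role as DIV_UP below
    warp_dot_demo<<<blocks, threads>>>(a, b, out, npoints, nchan);
    cudaDeviceSynchronize();

    printf("out[0] = %f (expected %f)\n", out[0], 0.5f * nchan);
    cudaFree(a); cudaFree(b); cudaFree(out);
    return 0;
}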
// coding=utf-8
//
-// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -39,8 +39,27 @@
#include <cub/cub.cuh>
#include <limits>
-using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
-using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#ifndef WARP_SIZE
#define WARP_SIZE (32)
#endif
#ifndef FULL_MASK
#define FULL_MASK (0xFFFFFFFF)
#endif
#ifndef THREADS
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#endif
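// Worked example (illustrative, using the 721 x 1440 output grid mentioned in the timing
// comments below): with THREADS = 64 and WARP_SIZE = 32 there are 2 warps per block, so
// covering the 721 * 1440 = 1,038,240 output points one-per-warp needs
// DIV_UP(1038240, 2) = 519,120 blocks.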
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#endif
__device__ static float atomicMax(float* address, float val)
{
@@ -54,6 +73,27 @@ __device__ static float atomicMax(float* address, float val)
return __int_as_float(old);
}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
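// Note: the XOR shuffle with strides 16, 8, 4, 2, 1 forms a butterfly reduction, so after the
// loop every lane of the warp holds the complete 32-lane sum (no broadcast step is needed).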
// easier to understand version of manual shfl_xor_sync, performance appears similar
__device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
__global__ void
s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
@@ -65,7 +105,7 @@ s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem; // 1
@@ -223,16 +263,16 @@ at::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dv_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -251,7 +291,7 @@ s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
@@ -386,7 +426,7 @@ s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon
__global__ void
s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
@@ -395,7 +435,7 @@ s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
@@ -666,16 +706,136 @@ __global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) *
(gdotv - integral));
atomicAdd(&dydv[batch_b][channel_idx][hi][wip],
(alpha_inz / alpha_sum) * sh_dy_ho_wo[channel_idx]);
}
}
__syncthreads();
}
// New kernel: s2_attention_bwd_dkvq_kernel_mbT
// This kernel assumes kx, vx, qy, dy, dydk, dydv, dydq are all [batch, ho, wo, channel] (transposed)
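// Notation used below (stated here for readability; it matches the code): for one output point
// (ho, wo) the forward pass computes y_c = sum_i alpha_i * vx_i,c / alpha_sum with
// alpha_i = exp(q . k_i) * quad_weights[hi] and alpha_sum = sum_i alpha_i. Writing
// gdotv_i = <dy, vx_i>, the chain rule gives
//   dydq_c   = (sum_i alpha_i k_i,c gdotv_i * alpha_sum - sum_i alpha_i gdotv_i * sum_i alpha_i k_i,c) / alpha_sum^2
//   dydk_i,c = q_c * (alpha_i / alpha_sum) * (gdotv_i - integral),  integral = sum_i alpha_i gdotv_i / alpha_sum
//   dydv_i,c = (alpha_i / alpha_sum) * dy_c
// which is exactly what the three passes below accumulate (sh_alpha_k, sh_alpha_vw, sh_alpha_kvw).
// The qdotk_max shift used in expf is a numerical-stability detail that cancels in these ratios.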
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel_mbT(
int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float* sh_alpha_k = sh + threadIdx.y * num_channels * 4;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float* sh_dy = sh_alpha_kvw + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) return;
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
// Zero shared memory
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][ho][wo][chan];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
float integral = 0.0f;
__syncthreads();
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int rlen = rend - rbeg;
// First pass: find qdotk_max
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
}
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
}
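// The first pass computes qdotk_max so that every expf argument in the passes below is <= 0;
// this prevents overflow and keeps the softmax numerically stable. The shift cancels out of
// all alpha_inz / alpha_sum ratios, so the results are unchanged.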
// Second pass: accumulate alpha_sum, integral, and shared stats
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
gdotv += sh_dy[chan] * vx[batchId][hi][wip][chan];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][hi][wip][chan];
sh_alpha_k[chan] += alpha_inz * kxval;
sh_alpha_vw[chan] += alpha_inz * gdotv;
sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
}
}
integral /= alpha_sum;
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][ho][wo][chan] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
// Third pass: accumulate gradients for k and v
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
gdotv += sh_dy[chan] * vx[batchId][hi][wip][chan];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float qyval = qy[batchId][ho][wo][chan];
float dyval = sh_dy[chan];
atomicAdd(&dydk[batchId][hi][wip][chan], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][hi][wip][chan], (alpha_inz / alpha_sum) * dyval);
}
}
}
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
@@ -713,16 +873,16 @@ at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dk_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -731,7 +891,7 @@ at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
}
at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
@@ -769,16 +929,16 @@ at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dq_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -787,12 +947,12 @@ at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
@@ -804,43 +964,137 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto stream = at::cuda::getCurrentCUDAStream().stream();

-torch::Tensor dydk = torch::zeros_like(qy);
-torch::Tensor dydv = torch::zeros_like(qy);
-torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
-size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0);
-// cuda grid y,z size limitations
-assert(nlon_out < 65535);
-assert(batch_size < 65535);
-// block-parallel over output points and batches
-dim3 gridDim(nlat_out,nlon_out,batch_size);
-// threads compute "blocks" of neighborhood and also "blocks" of channels
-dim3 blockDim(256, 1, 1);
-s2_attention_bwd_dkvq_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
-uo_num_channels, nlon_in, nlat_out, nlon_out,
-kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
-);
-C10_CUDA_KERNEL_LAUNCH_CHECK();
-return std::make_tuple(dydk, dydv, dydq);

// enum for which kernel version
enum KERNEL_VERSION {
OLD_VERSION = 0,
HOWO_WARP_VERSION = 2,
};
auto version = HOWO_WARP_VERSION;
// auto version = OLD_VERSION;
if (version == OLD_VERSION) {
printf("old version\n");
torch::Tensor dydk = torch::zeros_like(qy);
torch::Tensor dydv = torch::zeros_like(qy);
torch::Tensor dydq = torch::zeros_like(qy);
size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
// Define CUDA event variables for timing
cudaEvent_t start_event, stop_event;
cudaEventCreate(&start_event);
cudaEventCreate(&stop_event);
// Record the start event
cudaEventRecord(start_event, stream);
s2_attention_bwd_dkvq_kernel<<<
gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
// Record the stop event
cudaEventRecord(stop_event, stream);
cudaEventSynchronize(stop_event);
// Calculate elapsed time
float kernel_time_ms;
cudaEventElapsedTime(&kernel_time_ms, start_event, stop_event);
// Output the result
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// Old bwd kernel execution time: 803.477 ms
// std::cout << "Old bwd kernel execution time: " << kernel_time_ms << " ms" << std::endl;
// Cleanup events
cudaEventDestroy(start_event);
cudaEventDestroy(stop_event);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(dydk, dydv, dydq);
} else if (version == HOWO_WARP_VERSION) {
// Transpose to [batch, ho, wo, channel]
auto kxP = kx.permute({0,2,3,1}).contiguous();
auto vxP = vx.permute({0,2,3,1}).contiguous();
auto qyP = qy.permute({0,2,3,1}).contiguous();
auto dyP = dy.permute({0,2,3,1}).contiguous();
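// e.g. kx with layout [batch, channel, lat, lon] becomes kxP with layout [batch, lat, lon, channel];
// .contiguous() materializes the channel-last copy, so the kernel's [batch][hi][wip][chan]
// accesses read consecutive floats across the lanes of a warp.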
auto dydkP = torch::zeros_like(qyP);
auto dydvP = torch::zeros_like(qyP);
auto dydqP = torch::zeros_like(qyP);
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 4 * block.y; // 4 arrays per warp
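// Launch geometry: each block holds block.y = THREADS/WARP_SIZE warps, each warp (threadIdx.y)
// owns one output point (ho, wo), and grid.y indexes the batch. The dynamic shared memory
// provides the four per-warp, channel-length arrays used in the kernel
// (sh_alpha_k, sh_alpha_vw, sh_alpha_kvw, sh_dy).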
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel_mbT<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to [batch, channel, ho, wo]
auto dydk = dydkP.permute({0,3,1,2}).contiguous();
auto dydv = dydvP.permute({0,3,1,2}).contiguous();
auto dydq = dydqP.permute({0,3,1,2}).contiguous();
return std::make_tuple(dydk, dydv, dydq);
} else {
throw std::runtime_error("Invalid kernel version specified");
}
}
// coding=utf-8
//
-// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -34,12 +34,37 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define NNZ_TRESH (32)
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
@@ -70,9 +95,9 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
float* sh_qy_ho_wo = (float *)&sharedMem[2];
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
@@ -171,10 +196,105 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
// one warp per (ho,wo)
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_kernel_mbT(int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float *shy = sh + threadIdx.y*num_channels;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out)*nlon_in) {
return;
}
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho*nlon_out);
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
#else
shy[chan] = 0;
#endif
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int rlen = rend-rbeg;
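// The loop below is a streaming ("online") softmax: qdotk_max is updated as new neighbors are
// seen, and both the running normalizer alpha_sum and the running numerator shy[] are rescaled
// by exp_save = expf(old_max - new_max) before the new term is added. This gives the same
// result as a two-pass max-then-sum softmax while reading each neighbor only once.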
for(int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg+off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][ wo][chan]*
kx[batchId][hi][wip][chan];
}
qdotk = __warp_sum(qdotk);
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan]*exp_save + vx[batchId][hi][wip][chan]*alpha;
}
qdotk_max = qdotk_max_tmp;
}
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
}
return;
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
@@ -193,36 +313,44 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
auto stream = at::cuda::getCurrentCUDAStream().stream();

-// allocate output
-torch::Tensor y = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
-size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
-// cuda grid y,z size limitations
-assert(nlon_out < 65535);
-assert(batch_size < 65535);
-// block-parallel over output points and batches
-dim3 gridDim(nlat_out,nlon_out,batch_size);
-// threads compute "blocks" of neighborhood and also "blocks" of channels
-// note: blocksize of 512 runs into resource limits
-dim3 blockDim(256,1,1);
-s2_attention_kernel<<<gridDim, blockDim, sharedMemSize,stream>>>(
-uo_num_channels, nlon_in, nlat_out, nlon_out,
-kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
-);

// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
torch::Tensor kxP = kx.permute({0,2,3,1}).contiguous();
torch::Tensor vxP = vx.permute({0,2,3,1}).contiguous();
torch::Tensor qyP = qy.permute({0,2,3,1}).contiguous();
torch::Tensor y = torch::empty_like(qy);

dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float)*uo_num_channels * block.y;

cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));

s2_attention_kernel_mbT<THREADS>
<<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));

C10_CUDA_KERNEL_LAUNCH_CHECK();
...