Commit 5f051c97 authored by Max Rietmann

Optimized CUDA kernels for improved backward gradient computation



Introduce new CUDA kernels, `s2_attention_bwd_dkvq_kernel_mbT` and
`s2_attention_kernel_mbT`, for more efficient computation of the backward
gradients and the forward attention, respectively. The kernels improve memory
access patterns and enable coalesced loads by transposing the input tensors to
a channels-last layout.
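For illustration, the channels-last layout used by the new kernels corresponds to the
following host-side transposition (an illustrative snippet, not part of the diff below;
`x` is a placeholder tensor name):

    // x: [batch, channel, nlat, nlon]  ->  xP: [batch, nlat, nlon, channel]
    // consecutive lanes of a warp then read consecutive floats of xP (coalesced)
    torch::Tensor xP = x.permute({0, 2, 3, 1}).contiguous();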

Forward kernel written by Mauro Bisson.
Backward kernel written by Andrea Paris (aparis@ethz.ch) and Max Rietmann.

The parallelization strategy computes one output point per warp, with the threads of
the warp computing the channel dot products in parallel. Because the inputs are
transposed so that the channel dimension is last, the dot-product memory access
pattern is perfectly coalesced, which leads to excellent performance. This holds for
both the forward and the backward kernel; a sketch of the access pattern follows.
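As a rough sketch (illustrative only, simplified from the kernels below; `b`, `ho`, `wo`,
`hi`, `wip`, and `tidx` are assumed to be already-computed indices):

    // each warp owns one output point (ho, wo); lane tidx strides over channels
    float qdotk = 0.0f;
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        qdotk += qy[b][ho][wo][chan] * kx[b][hi][wip][chan];   // coalesced loads
    }
    qdotk = __warp_sum(qdotk);   // shuffle reduction; every lane receives the sum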
Co-authored-by: Mauro Bisson <maurob@nvidia.com>
Co-authored-by: Max Rietmann <mrietmann@nvidia.com>
Co-authored-by: Andrea Paris <aparis@ethz.ch>
parent 318fc76e
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
@@ -39,8 +39,27 @@
#include <cub/cub.cuh>
#include <limits>
#ifndef WARP_SIZE
#define WARP_SIZE (32)
#endif
#ifndef FULL_MASK
#define FULL_MASK (0xFFFFFFFF)
#endif
#ifndef THREADS
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#endif
__device__ static float atomicMax(float* address, float val)
{
@@ -54,6 +73,27 @@ __device__ static float atomicMax(float* address, float val)
return __int_as_float(old);
}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
// Easier-to-read alternative to the manual shfl_xor_sync reduction above; performance appears similar.
// Like __warp_sum, it returns the full-warp sum in every lane.
__device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
__global__ void
s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
@@ -667,13 +707,133 @@ __global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
          sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) *
              (gdotv - integral));
atomicAdd(&dydv[batch_b][channel_idx][hi][wip],
          (alpha_inz / alpha_sum) * sh_dy_ho_wo[channel_idx]);
}
}
__syncthreads();
}
// New kernel: s2_attention_bwd_dkvq_kernel_mbT
// This kernel assumes kx, vx, qy, dy, dydk, dydv, dydq are all [batch, ho, wo, channel] (transposed)
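// Parallelization: one warp per output point (ho, wo); lanes stride over the channel
// dimension, so the channels-last layout gives coalesced loads. The gradient is built
// in three passes over the sparse neighborhood of (ho, wo):
//   1) find max_i (q . k_i) for a numerically stable softmax,
//   2) accumulate A = sum_i alpha_i, integral = (sum_i alpha_i (g . v_i)) / A, and the
//      per-channel sums sum_i alpha_i k_ic, sum_i alpha_i (g . v_i), sum_i alpha_i k_ic (g . v_i),
//   3) scatter the dk and dv contributions with atomics,
// with alpha_i = expf(q . k_i - max) * quad_weights[hi] and g = dy at (ho, wo).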
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel_mbT(
int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float* sh_alpha_k = sh + threadIdx.y * num_channels * 4;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float* sh_dy = sh_alpha_kvw + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_out) return;
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
// Initialize the per-warp shared accumulators and cache dy for this output point
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][ho][wo][chan];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
float integral = 0.0f;
__syncthreads();
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int rlen = rend - rbeg;
// First pass: find qdotk_max
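// (subtracting qdotk_max before expf keeps every exponent <= 0, avoiding overflow in the softmax)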
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
}
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
}
// Second pass: accumulate alpha_sum, integral, and shared stats
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
gdotv += sh_dy[chan] * vx[batchId][hi][wip][chan];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][hi][wip][chan];
sh_alpha_k[chan] += alpha_inz * kxval;
sh_alpha_vw[chan] += alpha_inz * gdotv;
sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
}
}
integral /= alpha_sum;
// Write dydq
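// dydq[chan] = (sum_i alpha_i k_ic (g . v_i)) / A - (sum_i alpha_i (g . v_i)) * (sum_i alpha_i k_ic) / A^2,
// i.e. the derivative of the normalized attention output contracted with the incoming gradient g.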
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][ho][wo][chan] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
// Third pass: accumulate gradients for k and v
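// dydk[i][c] += q_c * (alpha_i / A) * ((g . v_i) - integral) and dydv[i][c] += (alpha_i / A) * g_c;
// atomics are needed because neighborhoods of different output points overlap in (hi, wip).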
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
gdotv += sh_dy[chan] * vx[batchId][hi][wip][chan];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float qyval = qy[batchId][ho][wo][chan];
float dyval = sh_dy[chan];
atomicAdd(&dydk[batchId][hi][wip][chan], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][hi][wip][chan], (alpha_inz / alpha_sum) * dyval);
}
}
}
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
                                    at::Tensor vx,
@@ -804,16 +964,25 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
// enum for which kernel version
enum KERNEL_VERSION {
OLD_VERSION = 0,
HOWO_WARP_VERSION = 2,
};
auto version = HOWO_WARP_VERSION;
// auto version = OLD_VERSION;
if (version == OLD_VERSION) {
printf("old version\n");
torch::Tensor dydk = torch::zeros_like(qy);
torch::Tensor dydv = torch::zeros_like(qy);
torch::Tensor dydq = torch::zeros_like(qy);
size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
@@ -824,7 +993,16 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
// Define CUDA event variables for timing
cudaEvent_t start_event, stop_event;
cudaEventCreate(&start_event);
cudaEventCreate(&stop_event);
// Record the start event
cudaEventRecord(start_event, stream);
s2_attention_bwd_dkvq_kernel<<<
gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -838,9 +1016,85 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
// Record the stop event
cudaEventRecord(stop_event, stream);
cudaEventSynchronize(stop_event);
// Calculate elapsed time
float kernel_time_ms;
cudaEventElapsedTime(&kernel_time_ms, start_event, stop_event);
// Output the result
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// Old bwd kernel execution time: 803.477 ms
// std::cout << "Old bwd kernel execution time: " << kernel_time_ms << " ms" << std::endl;
// Cleanup events
cudaEventDestroy(start_event);
cudaEventDestroy(stop_event);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(dydk, dydv, dydq);
} else if (version == HOWO_WARP_VERSION) {
// Transpose to [batch, ho, wo, channel]
auto kxP = kx.permute({0,2,3,1}).contiguous();
auto vxP = vx.permute({0,2,3,1}).contiguous();
auto qyP = qy.permute({0,2,3,1}).contiguous();
auto dyP = dy.permute({0,2,3,1}).contiguous();
auto dydkP = torch::zeros_like(qyP);
auto dydvP = torch::zeros_like(qyP);
auto dydqP = torch::zeros_like(qyP);
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
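// one warp (WARP_SIZE lanes) per output point (ho, wo); block.y warps per block, grid.y spans the batch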
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 4 * block.y; // 4 arrays per warp
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel_mbT<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to [batch, channel, ho, wo]
auto dydk = dydkP.permute({0,3,1,2}).contiguous();
auto dydv = dydvP.permute({0,3,1,2}).contiguous();
auto dydq = dydqP.permute({0,3,1,2}).contiguous();
return std::make_tuple(dydk, dydv, dydq);
} else {
throw std::runtime_error("Invalid kernel version specified");
}
}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
@@ -34,12 +34,37 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define NNZ_TRESH (32)
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
@@ -171,6 +196,101 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
// one warp per (ho,wo)
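// The neighborhood loop below uses a streaming ("online") softmax: a running maximum
// qdotk_max is kept and, whenever it increases, the accumulated numerator (shy) and
// denominator (alpha_sum) are rescaled by exp_save = expf(old_max - new_max). The final
// result equals sum_i alpha_i v_i / sum_i alpha_i with alpha_i = expf(q . k_i - max) * quad_weights[hi].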
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_kernel_mbT(int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float *shy = sh + threadIdx.y*num_channels;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out)*nlon_out) {
return;
}
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho*nlon_out);
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
#else
shy[chan] = 0;
#endif
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int rlen = rend-rbeg;
for(int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg+off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan]*
         kx[batchId][hi][wip][chan];
}
qdotk = __warp_sum(qdotk);
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan]*exp_save + vx[batchId][hi][wip][chan]*alpha;
}
qdotk_max = qdotk_max_tmp;
}
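// write the result back in the original [batch, channel, ho, wo] layout, so y needs no host-side permute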
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
}
return;
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    at::Tensor vx,
@@ -193,36 +313,44 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);

// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
torch::Tensor kxP = kx.permute({0,2,3,1}).contiguous();
torch::Tensor vxP = vx.permute({0,2,3,1}).contiguous();
torch::Tensor qyP = qy.permute({0,2,3,1}).contiguous();
torch::Tensor y = torch::empty_like(qy);

dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);

size_t shared_size = sizeof(float)*uo_num_channels * block.y;

cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));

s2_attention_kernel_mbT<THREADS>
    <<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));

C10_CUDA_KERNEL_LAUNCH_CHECK();
...