Unverified commit 26ce5cb5 authored by Thorsten Kurth, committed by GitHub

Merge pull request #77 from rietmann-nv/mr/bwd-channel-permute-experiments

Optimized CUDA kernels for S2 Attention (forward and backward)
parents 318fc76e 79fa6ad9
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......@@ -58,6 +58,7 @@ except ImportError as err:
attention_cuda_extension = None
_cuda_extension_available = False
_perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
class TestNeighborhoodAttentionS2(unittest.TestCase):
def setUp(self):
......@@ -65,10 +66,9 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333)
torch.manual_seed(333)
else:
self.device = torch.device("cpu")
torch.manual_seed(333)
@parameterized.expand(
[
......@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
]
],
skip_on_empty=True,
)
def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
......@@ -157,7 +158,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
]
],
skip_on_empty=True,
)
def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere"""
......@@ -212,5 +214,97 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
@unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
@parameterized.expand(
[
# self attention
#[1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
[1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
],
skip_on_empty=True,
)
def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# TODO: this test seems hardcoded for GPU. Is this necessary?
k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_gpu.requires_grad = False
v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_gpu.requires_grad = False
q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_gpu.requires_grad = False
# set up layers
time_layer_setup_start = torch.cuda.Event(enable_timing=True)
time_layer_setup_end = torch.cuda.Event(enable_timing=True)
time_layer_setup_start.record()
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
time_layer_setup_end.record()
torch.cuda.synchronize()
# random weights
with torch.no_grad():
att_gpu.q_weights.normal_()
att_gpu.k_weights.normal_()
att_gpu.v_weights.normal_()
att_gpu.q_bias.normal_()
att_gpu.k_bias.normal_()
att_gpu.v_bias.normal_()
# time forward pass
for i in range(2):
# warmup
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
time_forward_start = torch.cuda.Event(enable_timing=True)
time_forward_end = torch.cuda.Event(enable_timing=True)
time_forward_start.record()
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
time_forward_end.record()
torch.cuda.synchronize()
elapsed_time = time_forward_start.elapsed_time(time_forward_end)
if verbose:
print(f"Forward execution time: {elapsed_time} ms")
self.assertTrue(elapsed_time < _perf_test_thresholds["fwd_ms"])
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_gpu.q_weights)
att_gpu.k_weights.copy_(att_gpu.k_weights)
att_gpu.v_weights.copy_(att_gpu.v_weights)
att_gpu.q_bias.copy_(att_gpu.q_bias)
att_gpu.k_bias.copy_(att_gpu.k_bias)
att_gpu.v_bias.copy_(att_gpu.v_bias)
q_gpu = q_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
q_gpu.requires_grad = True
k_gpu = k_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
k_gpu.requires_grad = True
v_gpu = v_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device=self.device)
time_backward_start = torch.cuda.Event(enable_timing=True)
time_backward_end = torch.cuda.Event(enable_timing=True)
for i in range(2):
# warmup
out_gpu.backward(out_grad, retain_graph=True)
time_backward_start.record()
out_gpu.backward(out_grad)
time_backward_end.record()
torch.cuda.synchronize()
elapsed_time = time_backward_start.elapsed_time(time_backward_end)
if verbose:
print(f"Backward execution time: {elapsed_time} ms")
self.assertTrue(elapsed_time < _perf_test_thresholds["bwd_ms"])
if __name__ == "__main__":
unittest.main()
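A small helper, not part of this patch, distilled from the warmup-then-time pattern used in test_perf above (fn is assumed to be any CUDA-executing callable):

import torch

def time_cuda_ms(fn, warmup: int = 2) -> float:
    """Time a CUDA callable with events after a short warmup, as test_perf does."""
    if not torch.cuda.is_available():
        return float("nan")
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds

# usage inside the test, e.g.: time_cuda_ms(lambda: att_gpu(q_gpu, k_gpu, v_gpu))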
......@@ -43,25 +43,28 @@ except ImportError as err:
attention_cuda_extension = None
_cuda_extension_available = False
# s2 neighborhood attention forward pass
# uses qdotk_max update trick to avoid two loops when computing the softmax
# see e.g., https://arxiv.org/abs/1805.02867
# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                         quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                         nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
# prepare result tensor
y = torch.zeros_like(qy)
for ho in range(nlat_out):
# get number of nonzeros
zstart = row_off[ho]
zend = row_off[ho+1]
for wo in range(nlon_out):
alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
qdotk_nz = torch.zeros((y.shape[0], zend-zstart,), dtype=y.dtype, device=y.device)
qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
......@@ -75,24 +78,19 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
# compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo]
k_hi_wip = kx[:, :, hi, wip]
qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wip, dim=1)
qdotk_max, _ = torch.max(qdotk_nz, dim=1)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
# compute input indices from psi datastructure
hi = nz_col_idx // nlon_in
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
alpha = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max)
# softmax denominator
alpha_sum[:] += alpha[:] * quad_weights[hi]
# tmp max
qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
y[:,:,ho,wo] += alpha[:, None] * vx[:,:,hi,wip] * quad_weights[hi]
# alpha sum update
alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
# update output
y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
# define new max
qdotk_max = qdotk_max_tmp
y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
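A standalone sketch, not part of this diff, of the running-max update used in the loop above: streaming over the scores while rescaling the running sum and accumulator reproduces the ordinary two-pass softmax-weighted average (scores/values/weights are illustrative stand-ins for q.k, v and the quadrature weights):

import torch

torch.manual_seed(0)
scores = torch.randn(16)        # stand-in for q . k per neighbor
values = torch.randn(16, 8)     # stand-in for v per neighbor (8 channels)
weights = torch.rand(16)        # stand-in for quadrature weights

# two-pass reference: subtract the global max, then normalize
alpha_ref = torch.exp(scores - scores.max()) * weights
ref = (alpha_ref[:, None] * values).sum(dim=0) / alpha_ref.sum()

# single pass with a running max, as in _neighborhood_attention_s2_fwd_torch
run_max = torch.tensor(float("-inf"))
run_sum = torch.tensor(0.0)
acc = torch.zeros(values.shape[1])
for s, v, w in zip(scores, values, weights):
    new_max = torch.maximum(run_max, s)
    a = torch.exp(s - new_max) * w
    rescale = torch.exp(run_max - new_max)
    run_sum = a + run_sum * rescale
    acc = acc * rescale + a * v
    run_max = new_max

assert torch.allclose(acc / run_sum, ref, atol=1e-6)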
......
......@@ -49,30 +49,3 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
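For orientation, a small illustrative sketch (not from this patch) of how the psi_row_off / psi_col_idx CSR pair passed to these functions is decoded: the row offsets select the neighborhood of an output latitude ho, each column index packs hi * nlon_in + wi, and the stored longitude is shifted circularly by the output longitude wo:

import torch

def decode_neighborhood(ho: int, wo: int, nlon_in: int,
                        psi_row_off: torch.Tensor, psi_col_idx: torch.Tensor):
    # rows of the CSR structure are output latitudes
    zstart, zend = int(psi_row_off[ho]), int(psi_row_off[ho + 1])
    cols = psi_col_idx[zstart:zend]
    hi = cols // nlon_in            # input latitude
    wi = cols % nlon_in             # longitude stored relative to wo = 0
    wip = (wi + wo) % nlon_in       # circular shift by the output longitude
    return hi, wip

# toy pattern: two output latitudes over a 4-point longitude grid
psi_row_off = torch.tensor([0, 3, 5])        # neighborhood sizes 3 and 2
psi_col_idx = torch.tensor([0, 1, 5, 4, 7])  # flattened hi * nlon_in + wi
print(decode_neighborhood(ho=1, wo=2, nlon_in=4,
                          psi_row_off=psi_row_off, psi_col_idx=psi_col_idx))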
......@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
......@@ -29,6 +29,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include "c10/core/MemoryFormat.h"
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
......@@ -36,763 +37,206 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <ctime>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
__global__ void
s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem; // 1
float* sh_qdotk_max = (float*)&sharedMem[1]; // 1
float* sh_qy_ho_wo = (float*)&sharedMem[2]; // num_channels
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
#ifndef WARP_SIZE
#define WARP_SIZE (32)
#endif
#ifndef FULL_MASK
#define FULL_MASK (0xFFFFFFFF)
#endif
#ifndef THREADS
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#endif
#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer {
public:
explicit ScopeTimer(const std::string& label = "")
: label_(label), start_(std::chrono::high_resolution_clock::now()) {}
~ScopeTimer() {
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
}
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
// form alpha & sum alpha
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
private:
std::string label_;
std::chrono::high_resolution_clock::time_point start_;
};
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
// collect thread-local alpha_sum
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
// alpha * dy * omega / alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
return val;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, dy, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * dy[batch_b][channel_idx][ho][wo]);
}
}
}
at::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydv = torch::zeros_like(vx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dv_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydv;
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
__global__ void
s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_integral = (float *)&sharedMem[1 + num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 2 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
// This kernel computes the backward pass for the S2 attention mechanism, using
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel(
int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float* sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_dy = sh_alpha_kvw + num_channels;
float* sh_qy = sh_dy + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) return;
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
// Zero shared memory
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][chan][ho][wo];
sh_qy[chan] = qy[batchId][chan][ho][wo];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
float integral = 0.0f;
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int rlen = rend - rbeg;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
// First pass: find qdotk_max
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
}
// compute max over all threads
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
// Second pass: accumulate alpha_sum, integral, and shared stats
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
}
// softmax numerator
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha & integral
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
}
// block sum thread-local alpha_sum and integral
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finish integral computation
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// broadcast sum and integral back to thread-local registers
integral = sh_integral[0];
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz/alpha_sum) * (gdotv - integral));
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][chan][hi][wip];
sh_alpha_k[chan] += alpha_inz * kxval;
sh_alpha_vw[chan] += alpha_inz * gdotv;
sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
}
}
__syncthreads();
}
__global__ void
s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_alpha_k = (float *)&sharedMem[1 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[1 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[1 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[1 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[1 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
integral /= alpha_sum;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][chan][ho][wo] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
// Third pass: accumulate gradients for k and v
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
}
// softmax numerator
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float qyval = qy[batchId][chan][ho][wo];
float dyval = sh_dy[chan];
atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
}
}
// sum thread-local alpha_sums across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
}
__global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>
dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_integral = (float*)&sharedMem[1];
float *sh_qy_ho_wo = (float *)&sharedMem[2];
float *sh_alpha_k = (float *)&sharedMem[2 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[2 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[2 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
integral += alpha_inz * gdotv;
}
// sum thread-local alpha_sums & integral across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finalize integral
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// "broadcast" alpha sum & integral back to thread-local registers
alpha_sum = sh_alpha_sum[0];
integral = sh_integral[0];
// dq
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
__syncthreads();
// dk & dv
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) *
(gdotv - integral));
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * sh_dy_ho_wo[channel_idx]);
}
}
__syncthreads();
}
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(kx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (2*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dk_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydk;
}
at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (5*uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dq_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydq;
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
                                                                        at::Tensor qy,
                                                                        at::Tensor dy,
                                                                        at::Tensor quad_weights,
                                                                        at::Tensor psi_col_idx,
                                                                        at::Tensor psi_row_off,
                                                                        int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
......@@ -804,43 +248,110 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(qy);
torch::Tensor dydv = torch::zeros_like(qy);
torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
auto dy_channel_first = dy.strides()[1] == 1;
// Transpose to [batch, ho, wo, channel]
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
// auto* permute_timer = new ScopeTimer("permute inputs");
//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
auto dyP = at::Tensor();
if (!dy_channel_first) {
// printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
dyP = dy;
}
// cudaDeviceSynchronize();
// delete permute_timer;
nvtxRangePop();
size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
auto dydk = torch::zeros_like(qyP);
auto dydv = torch::zeros_like(qyP);
auto dydq = torch::zeros_like(qyP);
// print strdie of dydkP, dydvP, dydqP
nvtxRangePop();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dkvq_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp: alpha_k, alpha_vw, alpha_kvw, dy, qy
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
if(!k_channel_first) dydk = dydk.contiguous();
if(!v_channel_first) dydv = dydv.contiguous();
if(!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
// printf("%ld,", stride);
// }
// printf("]\n");
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq);
}
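As a sanity check on the closed-form gradients accumulated by the fused backward kernel launched above (dydq from a quotient rule over the alpha_k / alpha_vw / alpha_kvw accumulators, dydk and dydv from the normalized softmax weights), here is a small dense sketch for a single output point, not part of the patch, compared against torch.autograd:

import torch

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
C, N = 8, 16                                  # channels, neighborhood size
q = torch.randn(C, requires_grad=True)
k = torch.randn(C, N, requires_grad=True)
v = torch.randn(C, N, requires_grad=True)
w = torch.rand(N)                             # quadrature weights
dy = torch.randn(C)                           # incoming gradient

# forward: quadrature-weighted softmax average over the neighborhood
alpha = torch.exp(q @ k) * w                  # (N,)
y = (v * alpha).sum(dim=1) / alpha.sum()      # (C,)
dq_ref, dk_ref, dv_ref = torch.autograd.grad(y, (q, k, v), grad_outputs=dy)

# closed-form expressions mirroring the kernel's per-warp accumulators
with torch.no_grad():
    A = alpha.sum()                           # alpha_sum
    g = dy @ v                                # gdotv per neighbor, (N,)
    integral = (alpha * g).sum() / A
    alpha_k = (k * alpha).sum(dim=1)          # sh_alpha_k
    alpha_vw = (alpha * g).sum()              # sh_alpha_vw (same value for every channel)
    alpha_kvw = (k * alpha * g).sum(dim=1)    # sh_alpha_kvw
    dq = (alpha_kvw * A - alpha_vw * alpha_k) / (A * A)
    dk = q[:, None] * (alpha / A) * (g - integral)
    dv = dy[:, None] * (alpha / A)

assert torch.allclose(dq, dq_ref) and torch.allclose(dk, dk_ref) and torch.allclose(dv, dv_ref)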
......@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
......@@ -34,147 +34,149 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define NNZ_TRESH (32)
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
__global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_qdotk_max = (float*)&sharedMem[1];
float* sh_qy_ho_wo = (float *)&sharedMem[2];
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
y[batch_b][channel_idx][ho][wo] = 0.0;
// one warp per (ho,wo)
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_kernel(int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float *shy = sh + threadIdx.y*num_channels;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out)*nlon_in) {
return;
}
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho*nlon_out);
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
#else
shy[chan] = 0;
#endif
}
__syncthreads();
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int rlen = rend-rbeg;
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
for(int off = 0; off < rlen; off++) {
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
const int64_t col = psi_col_idx[rbeg+off];
int nz_col_idx = psi_col_idx[psi_offset+idz];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float qdotk = 0.0f;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][ wo]*
kx[batchId][chan][hi][wip];
}
qdotk_max = std::max(qdotk_max, qdotk);
}
qdotk = __warp_sum_cub(qdotk);
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float qdotk_max_tmp;
float alpha;
float exp_save;
float alpha_sum = 0.0f;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
float alpha_inz = 0.0;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz < psi_nnz_ho) {
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator with minus qdotk_max to avoid numerical overflow.
// Because qdotk_max is in both numerator and denominator (due to
// alpha_sum), it doesn't affect the solution other than removing overflow
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// thread-local sum alpha
alpha_sum += alpha_inz;
// multiply alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&y[batch_b][channel_idx][ho][wo], alpha_inz * vx[batch_b][channel_idx][hi][wip]);
}
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan]*exp_save + vx[batchId][chan][hi][wip]*alpha;
}
qdotk_max = qdotk_max_tmp;
}
// collect all alpha_sum across threads
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// rebroadcast sum to all threads
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
y[batch_b][channel_idx][ho][wo] /= alpha_sum;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
}
return;
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    at::Tensor vx,
                                    at::Tensor qy,
                                    at::Tensor quad_weights,
                                    at::Tensor psi_col_idx,
                                    at::Tensor psi_row_off,
......@@ -193,36 +195,73 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
auto stream = at::cuda::getCurrentCUDAStream().stream();
// allocate output
torch::Tensor y = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
// note: blocksize of 512 runs into resource limits
dim3 blockDim(256,1,1);
s2_attention_kernel<<<gridDim, blockDim, sharedMemSize,stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
cudaDeviceSynchronize();
nvtxRangePop();
torch::Tensor y = torch::empty_like(qy);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float)*uo_num_channels * block.y;
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_kernel<THREADS>
<<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
// match output layout to input layout
if (!q_channel_first) y = y.contiguous();
C10_CUDA_KERNEL_LAUNCH_CHECK();
......
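The permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2) dance used by both host functions keeps the logical [batch, channel, ho, wo] shape while making the channel dimension stride-1 in memory, which is exactly what the strides()[1] == 1 check detects; a tiny illustration in plain torch (no extension required):

import torch

x = torch.randn(2, 8, 6, 12)   # [batch, channel, ho, wo], channels are slow-varying
print(x.stride())              # (576, 72, 12, 1): stride(1) == ho * wo

# same logical shape, but channels-last in memory
xp = x.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)
print(xp.shape, xp.stride())   # torch.Size([2, 8, 6, 12]) with stride(1) == 1
assert xp.stride(1) == 1       # the condition the host code tests before permuting
assert torch.equal(x, xp)      # values and indexing are unchanged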
......@@ -33,10 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
m.def("backward_dq", &s2_attention_bwd_dq_cuda,
"(Local) Attention gradient on S2 (gradient for q)");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}