Unverified commit 26ce5cb5, authored by Thorsten Kurth and committed by GitHub

Merge pull request #77 from rietmann-nv/mr/bwd-channel-permute-experiments

Optimized CUDA kernels for S2 Attention (forward and backward)
parents 318fc76e 79fa6ad9
# coding=utf-8 # coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
...@@ -58,6 +58,7 @@ except ImportError as err: ...@@ -58,6 +58,7 @@ except ImportError as err:
attention_cuda_extension = None attention_cuda_extension = None
_cuda_extension_available = False _cuda_extension_available = False
_perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
class TestNeighborhoodAttentionS2(unittest.TestCase): class TestNeighborhoodAttentionS2(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -65,7 +66,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -65,7 +66,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
self.device = torch.device("cuda:0") self.device = torch.device("cuda:0")
torch.cuda.set_device(self.device.index) torch.cuda.set_device(self.device.index)
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
torch.manual_seed(333)
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
torch.manual_seed(333) torch.manual_seed(333)
...@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3], [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3], [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
] ],
skip_on_empty=True,
) )
def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True): def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation""" """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
...@@ -157,7 +158,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -157,7 +158,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3], # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0], [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0], [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
] ],
skip_on_empty=True,
) )
def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True): def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
"""Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere""" """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere"""
...@@ -212,5 +214,97 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -212,5 +214,97 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch") self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
@unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
@parameterized.expand(
[
# self attention
#[1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
[1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
],
skip_on_empty=True,
)
def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
# extract some parameters
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# TODO: this test seems hardcoded for GPU. Is this necessary?
k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
k_gpu.requires_grad = False
v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
v_gpu.requires_grad = False
q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
q_gpu.requires_grad = False
# set up layers
time_layer_setup_start = torch.cuda.Event(enable_timing=True)
time_layer_setup_end = torch.cuda.Event(enable_timing=True)
time_layer_setup_start.record()
att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
in_shape=in_shape, out_shape=out_shape,
grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
time_layer_setup_end.record()
torch.cuda.synchronize()
# random weights
with torch.no_grad():
att_gpu.q_weights.normal_()
att_gpu.k_weights.normal_()
att_gpu.v_weights.normal_()
att_gpu.q_bias.normal_()
att_gpu.k_bias.normal_()
att_gpu.v_bias.normal_()
# time forward pass
for i in range(2):
# warmup
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
time_forward_start = torch.cuda.Event(enable_timing=True)
time_forward_end = torch.cuda.Event(enable_timing=True)
time_forward_start.record()
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
time_forward_end.record()
torch.cuda.synchronize()
elapsed_time = time_forward_start.elapsed_time(time_forward_end)
if verbose:
print(f"Forward execution time: {elapsed_time} ms")
self.assertTrue(elapsed_time < _perf_test_thresholds["fwd_ms"])
# sync weights:
with torch.no_grad():
att_gpu.q_weights.copy_(att_gpu.q_weights)
att_gpu.k_weights.copy_(att_gpu.k_weights)
att_gpu.v_weights.copy_(att_gpu.v_weights)
att_gpu.q_bias.copy_(att_gpu.q_bias)
att_gpu.k_bias.copy_(att_gpu.k_bias)
att_gpu.v_bias.copy_(att_gpu.v_bias)
q_gpu = q_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
q_gpu.requires_grad = True
k_gpu = k_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
k_gpu.requires_grad = True
v_gpu = v_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device=self.device)
time_backward_start = torch.cuda.Event(enable_timing=True)
time_backward_end = torch.cuda.Event(enable_timing=True)
for i in range(2):
# warmup
out_gpu.backward(out_grad, retain_graph=True)
time_backward_start.record()
out_gpu.backward(out_grad)
time_backward_end.record()
torch.cuda.synchronize()
elapsed_time = time_backward_start.elapsed_time(time_backward_end)
if verbose:
print(f"Backward execution time: {elapsed_time} ms")
self.assertTrue(elapsed_time < _perf_test_thresholds["bwd_ms"])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -43,7 +43,10 @@ except ImportError as err: ...@@ -43,7 +43,10 @@ except ImportError as err:
attention_cuda_extension = None attention_cuda_extension = None
_cuda_extension_available = False _cuda_extension_available = False
# s2 neighborhood attention forward pass
# uses qdotk_max update trick to avoid two loops when computing the softmax
# see e.g., https://arxiv.org/abs/1805.02867
# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
...@@ -61,7 +64,7 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: ...@@ -61,7 +64,7 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
for wo in range(nlon_out): for wo in range(nlon_out):
alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
qdotk_nz = torch.zeros((y.shape[0], zend-zstart,), dtype=y.dtype, device=y.device) qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
for idz in range(zstart, zend): for idz in range(zstart, zend):
nz_col_idx = col_idx[idz] nz_col_idx = col_idx[idz]
...@@ -75,24 +78,19 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: ...@@ -75,24 +78,19 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
# compute correlation & softmax numerator # compute correlation & softmax numerator
q_ho_wo = qy[:, :, ho, wo] q_ho_wo = qy[:, :, ho, wo]
k_hi_wip = kx[:, :, hi, wip] k_hi_wip = kx[:, :, hi, wip]
qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wip, dim=1) qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
qdotk_max, _ = torch.max(qdotk_nz, dim=1)
for idz in range(zstart, zend):
nz_col_idx = col_idx[idz]
# compute input indices from psi datastructure # tmp max
hi = nz_col_idx // nlon_in qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
# account for output shift and ensure positive index due to circular condition
wi = nz_col_idx % nlon_in
wip = (wi + wo) % nlon_in
alpha = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max)
# softmax denominator
alpha_sum[:] += alpha[:] * quad_weights[hi]
y[:,:,ho,wo] += alpha[:, None] * vx[:,:,hi,wip] * quad_weights[hi] # alpha sum update
alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
# update output
y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
# define new max
qdotk_max = qdotk_max_tmp
y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
......
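The reference forward pass above implements the single-pass ("online") softmax described in the comment: whenever a larger q·k score is encountered, the running maximum, the quadrature-weighted denominator, and the value-weighted numerator are all rescaled by exp(old_max - new_max), so each neighborhood is traversed only once. A minimal standalone sketch of that recurrence, with hypothetical scalar scores, values, and quadrature weights (illustration only, not part of this PR):

import math

# hypothetical per-neighbor scores q.k, values v, and quadrature weights w
scores  = [0.3, 2.1, -0.7, 1.4]
values  = [1.0, 2.0, 3.0, 4.0]
weights = [0.2, 0.3, 0.3, 0.2]

m, denom, num = float("-inf"), 0.0, 0.0
for s, v, w in zip(scores, values, weights):
    m_new   = max(m, s)                  # running maximum of the scores
    rescale = math.exp(m - m_new)        # 0.0 on the first step (exp(-inf))
    alpha   = math.exp(s - m_new) * w    # quadrature-weighted softmax numerator term
    denom   = alpha + denom * rescale    # rescaled running denominator (alpha_sum)
    num     = num * rescale + alpha * v  # rescaled running numerator (y before division)
    m = m_new

# two-pass reference: subtract the global maximum, then normalize
m_ref = max(scores)
ref = sum(math.exp(s - m_ref) * w * v for s, v, w in zip(scores, values, weights)) \
    / sum(math.exp(s - m_ref) * w for s, w in zip(scores, weights))
assert abs(num / denom - ref) < 1e-12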
...@@ -49,30 +49,3 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens ...@@ -49,30 +49,3 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
at::Tensor psi_col_idx, at::Tensor psi_col_idx,
at::Tensor psi_row_off, at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out); int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh" #include "attention.cuh"
#include "c10/core/MemoryFormat.h"
#include <ATen/core/TensorAccessor.h> #include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/TensorInfo.cuh> #include <ATen/cuda/detail/TensorInfo.cuh>
...@@ -36,755 +37,198 @@ ...@@ -36,755 +37,198 @@
#include <ATen/cuda/detail/IndexUtils.cuh> #include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h> #include <ATen/cuda/CUDAUtils.h>
#include <ctime>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <limits> #include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>; #ifndef WARP_SIZE
using BlockReduceFloat512 = cub::BlockReduce<float, 512>; #define WARP_SIZE (32)
#endif
__device__ static float atomicMax(float* address, float val) #ifndef FULL_MASK
{ #define FULL_MASK (0xFFFFFFFF)
int* address_as_i = (int*) address; #endif
int old = *address_as_i, assumed; #ifndef THREADS
do { #define THREADS (64)
assumed = old; #endif
old = ::atomicCAS(address_as_i, assumed, #ifndef DIV_UP
__float_as_int(::fmaxf(val, __int_as_float(assumed)))); #define DIV_UP(a,b) (((a)+((b)-1))/(b))
} while (assumed != old); #endif
return __int_as_float(old); #ifndef CHECK_CUDA
} #define CHECK_CUDA(call) { \
cudaError_t err = call; \
__global__ void if( cudaSuccess != err) { \
s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out, fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx, __FILE__, __LINE__, cudaGetErrorString( err) ); \
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx, exit(EXIT_FAILURE); \
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy, }}
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy, #endif
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx, #include <iostream>
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset, #include <chrono>
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) #include <string>
{
// shared memory class ScopeTimer {
extern __shared__ float sharedMem[]; public:
explicit ScopeTimer(const std::string& label = "")
float* sh_alpha_sum = (float*)&sharedMem; // 1 : label_(label), start_(std::chrono::high_resolution_clock::now()) {}
float* sh_qdotk_max = (float*)&sharedMem[1]; // 1
float* sh_qy_ho_wo = (float*)&sharedMem[2]; // num_channels ~ScopeTimer() {
auto end = std::chrono::high_resolution_clock::now();
if (threadIdx.x == 0) { auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
sh_qdotk_max[0] = std::numeric_limits<float>::lowest(); std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
sh_alpha_sum[0] = 0.0; }
}
__syncthreads(); private:
std::string label_;
int ho = blockIdx.x; std::chrono::high_resolution_clock::time_point start_;
int wo = blockIdx.y; };
int batch_b = blockIdx.z;
static __device__ float __warp_sum(float val) {
// load qy channels into shared memory #pragma unroll
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) { for(int i = WARP_SIZE/2; i; i /= 2) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x; val += __shfl_xor_sync(FULL_MASK, val, i);
if(channel_idx >= num_channels) break; }
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo]; return val;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
// form alpha & sum alpha
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
}
// collect thread-local alpha_sum
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
// alpha * dy * omega / alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, dy, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * dy[batch_b][channel_idx][ho][wo]);
}
}
}
at::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydv = torch::zeros_like(vx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dv_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydv;
} }
__global__ void // easier to understand version of manual shfl_xor_sync, performance appears similar
s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out, static __device__ float __warp_sum_cub(float val) {
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx, // use cub to reduce within a warp
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx, __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_integral = (float *)&sharedMem[1 + num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 2 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = max(qdotk_max, qdotk);
}
// compute max over all threads
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha & integral
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
}
// block sum thread-local alpha_sum and integral // 1. Compute sum (initially only in lane 0)
atomicAdd(&sh_alpha_sum[0], alpha_sum); float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
atomicAdd(&sh_integral[0], integral); // 2. Broadcast sum to all threads
__syncthreads(); sum = __shfl_sync(0xFFFFFFFF, sum, 0);
// finish integral computation return sum;
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// broadcast sum and integral back to thread-local registers
integral = sh_integral[0];
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz/alpha_sum) * (gdotv - integral));
}
}
__syncthreads();
} }
__global__ void // This kernel computes the backward pass for the S2 attention mechanism, using
s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out, // shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel(
int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_alpha_k = (float *)&sharedMem[1 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[1 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[1 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[1 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[1 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
}
// sum thread-local alpha_sums across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
}
__global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>
dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk, torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv, torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq, torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx, const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset, const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
{
// shared memory extern __shared__ float sh[];
extern __shared__ float sharedMem[]; float* sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float *sh_alpha_sum = (float *)&sharedMem; float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float* sh_integral = (float*)&sharedMem[1]; float *sh_dy = sh_alpha_kvw + num_channels;
float *sh_qy_ho_wo = (float *)&sharedMem[2]; float* sh_qy = sh_dy + num_channels;
float *sh_alpha_k = (float *)&sharedMem[2 + num_channels]; // (optionally, could use more shared memory for other intermediates)
float *sh_alpha_vw = (float *)&sharedMem[2 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[2 + 3*num_channels]; const uint64_t batchId = blockIdx.y;
float *sh_dy_ho_wo = (float *)&sharedMem[2 + 4 * num_channels]; const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
float *sh_qdotk_max = (float *)&sharedMem[2 + 5 * num_channels]; if (wid >= uint64_t(nlat_out) * nlon_in) return;
const int tidx = threadIdx.x;
if (threadIdx.x == 0) { const int ho = wid / nlon_out;
sh_alpha_sum[0] = 0.0; const int wo = wid - (ho * nlon_out);
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest(); // Zero shared memory
} for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][chan][ho][wo];
sh_qy[chan] = qy[batchId][chan][ho][wo];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
float integral = 0.0f;
__syncthreads(); __syncthreads();
int ho = blockIdx.x; const int64_t rbeg = psi_row_offset[ho];
int wo = blockIdx.y; const int64_t rend = psi_row_offset[ho+1];
int batch_b = blockIdx.z; const int rlen = rend - rbeg;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz]; // First pass: find qdotk_max
for (int off = 0; off < rlen; off++) {
// compute input indices from psi datastructure const int64_t col = psi_col_idx[rbeg + off];
int hi = nz_col_idx / nlon_in; const int hi = col / nlon_in;
// account for output shift and ensure positive index due to circular condition const int wi = col - (hi * nlon_in);
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in; const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f; float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) { for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip]; qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
} }
qdotk_max = std::max(qdotk, qdotk_max); qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
} }
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers // Second pass: accumulate alpha_sum, integral, and shared stats
qdotk_max = sh_qdotk_max[0]; for (int off = 0; off < rlen; off++) {
float alpha_sum = 0.0; const int64_t col = psi_col_idx[rbeg + off];
float integral = 0.0; const int hi = col / nlon_in;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) { const int wi = col - (hi * nlon_in);
int idz = psi_block*blockDim.x + threadIdx.x; const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
// skip if index >= length of psi_idx because last loop iteration will have extra threads float qdotk = 0.0f, gdotv = 0.0f;
if(idz >= psi_nnz_ho) break; for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
int nz_col_idx = psi_col_idx[psi_offset+idz]; gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
} }
// softmax numerator qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz; alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
integral += alpha_inz * gdotv; integral += alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][chan][hi][wip];
sh_alpha_k[chan] += alpha_inz * kxval;
sh_alpha_vw[chan] += alpha_inz * gdotv;
sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
}
} }
// sum thread-local alpha_sums & integral across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finalize integral
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// "broadcast" alpha sum & integral back to thread-local registers
alpha_sum = sh_alpha_sum[0];
integral = sh_integral[0];
// dq integral /= alpha_sum;
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum); // Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][chan][ho][wo] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
} }
__syncthreads();
// dk & dv // Third pass: accumulate gradients for k and v
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) { for (int off = 0; off < rlen; off++) {
int idz = psi_block*blockDim.x + threadIdx.x; const int64_t col = psi_col_idx[rbeg + off];
// skip if index >= length of psi_idx because last loop iteration will have extra threads const int hi = col / nlon_in;
if(idz >= psi_nnz_ho) break; const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
int nz_col_idx = psi_col_idx[psi_offset+idz]; float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
// compute input indices from psi datastructure qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
int hi = nz_col_idx / nlon_in; gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
} }
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) { float qyval = qy[batchId][chan][ho][wo];
atomicAdd(&dydk[batch_b][channel_idx][hi][wip], float dyval = sh_dy[chan];
sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) * atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
(gdotv - integral)); atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * sh_dy_ho_wo[channel_idx]);
} }
} }
__syncthreads();
}
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(kx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (2*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dk_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydk;
} }
at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (5*uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dq_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydq;
}
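// For reference (a sketch of the math, not part of the commit): with alpha_i = exp(q.k_i - m) * w_i,
// A = sum_i alpha_i, and y[c] = sum_i alpha_i * v_i[c] / A, the quantities accumulated by the fused
// kernel above correspond to
//
//   dL/dq[c]     = ( sum_i alpha_i * k_i[c] * (g.v_i) * A
//                    - (sum_i alpha_i * (g.v_i)) * (sum_j alpha_j * k_j[c]) ) / A^2
//   dL/dk_i[c]  += q[c] * (alpha_i / A) * ( g.v_i - (1/A) * sum_j alpha_j * (g.v_j) )
//   dL/dv_i[c]  += (alpha_i / A) * g[c]
//
// where g = dy at the output point: sh_alpha_kvw, sh_alpha_vw and sh_alpha_k hold the three sums in
// the dq expression, "integral" holds (1/A) * sum_j alpha_j * (g.v_j), and the third pass performs
// the atomicAdd updates for dk and dv.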
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy, at::Tensor qy,
...@@ -804,43 +248,110 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens ...@@ -804,43 +248,110 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(qy); auto k_channel_first = kx.strides()[1] == 1;
torch::Tensor dydv = torch::zeros_like(qy); auto v_channel_first = vx.strides()[1] == 1;
torch::Tensor dydq = torch::zeros_like(qy); auto q_channel_first = qy.strides()[1] == 1;
auto dy_channel_first = dy.strides()[1] == 1;
// Transpose to [batch, ho, wo, channel]
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
// auto* permute_timer = new ScopeTimer("permute inputs");
//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
auto dyP = at::Tensor();
if (!dy_channel_first) {
// printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
dyP = dy;
}
// cudaDeviceSynchronize();
// delete permute_timer;
nvtxRangePop();
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
auto dydk = torch::zeros_like(qyP);
auto dydv = torch::zeros_like(qyP);
auto dydq = torch::zeros_like(qyP);
// print strides of dydkP, dydvP, dydqP
nvtxRangePop();
size_t uo_num_channels = kx.size(1); size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0); const int batch_size = kx.size(0);
// cuda grid y,z size limitations dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
assert(nlon_out < 65535); dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
assert(batch_size < 65535); size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp
// block-parallel over output points and batches cudaEvent_t start, stop;
dim3 gridDim(nlat_out,nlon_out,batch_size); float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
// threads compute "blocks" of neighborhood and also "blocks" of channels s2_attention_bwd_dkvq_kernel<THREADS><<<
dim3 blockDim(256, 1, 1); grid, block, shared_size, stream>>>(
s2_attention_bwd_dkvq_kernel<<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(), psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(), psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>() quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
);
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
if(!k_channel_first) dydk = dydk.contiguous();
if(!v_channel_first) dydv = dydv.contiguous();
if(!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
// printf("%ld,", stride);
// }
// printf("]\n");
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq); return std::make_tuple(dydk, dydv, dydq);
} }
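The host wrapper above permutes each input so that the channel dimension becomes the fastest-varying (unit-stride) one while the logical [batch, channel, h, w] shape is preserved, which is what the channel-strided warp loads in the kernel rely on. A minimal PyTorch sketch of that trick (sizes and names made up for illustration, not part of the commit):

import torch

x = torch.randn(2, 8, 6, 12)        # [batch, channel, h, w]; channel stride is 72, not 1
print(x.stride())                   # (576, 72, 12, 1)

# permute to [batch, h, w, channel], materialize that layout, then view it back
# so the logical shape is unchanged but channels are now unit-stride in memory
xp = x.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)

print(xp.shape)                     # torch.Size([2, 8, 6, 12]) -- same logical shape
print(xp.stride())                  # (576, 1, 96, 8) -- channel stride is now 1
assert torch.equal(x, xp)           # same values, different memory layout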
...@@ -34,25 +34,65 @@ ...@@ -34,25 +34,65 @@
#include <ATen/cuda/detail/IndexUtils.cuh> #include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h> #include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <limits> #include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>; using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>; using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
__device__ static float atomicMax(float* address, float val) #define WARP_SIZE (32)
{ #define FULL_MASK (0xFFFFFFFF)
int* address_as_i = (int*) address; #define THREADS (64)
int old = *address_as_i, assumed; #define DIV_UP(a,b) (((a)+((b)-1))/(b))
do {
assumed = old; #define NNZ_TRESH (32)
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed)))); #define CHECK_CUDA(call) { \
} while (assumed != old); cudaError_t err = call; \
return __int_as_float(old); if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
} }
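// For illustration (not part of the commit): a standalone kernel exercising the same xor-shuffle
// butterfly as __warp_sum above. Launched with a single 32-thread warp, every lane contributes its
// lane id, so after the reduction every lane holds 0 + 1 + ... + 31 = 496.
__global__ void warp_sum_demo(float* out)
{
    float val = float(threadIdx.x);                 // one value per lane
    #pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) {        // offsets 16, 8, 4, 2, 1
        val += __shfl_xor_sync(FULL_MASK, val, i);  // exchange with lane (laneId ^ i) and accumulate
    }
    out[threadIdx.x] = val;                         // 496.0f in every lane
}
// hypothetical launch: warp_sum_demo<<<1, WARP_SIZE>>>(d_out);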
__global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
// one warp per (ho,wo)
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_kernel(int num_channels,
int nlon_in,
int nlat_out,
int nlon_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
...@@ -60,115 +100,77 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out, ...@@ -60,115 +100,77 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y, torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx, const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset, const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
{
// shared memory extern __shared__ float sh[];
extern __shared__ float sharedMem[]; float *shy = sh + threadIdx.y*num_channels;
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_qdotk_max = (float*)&sharedMem[1]; const uint64_t batchId = blockIdx.y;
float* sh_qy_ho_wo = (float *)&sharedMem[2]; const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;
if (threadIdx.x == 0) { if (wid >= uint64_t(nlat_out)*nlon_out) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest(); return;
sh_alpha_sum[0] = 0.0;
} }
__syncthreads();
const int tidx = threadIdx.x;
int ho = blockIdx.x;
int wo = blockIdx.y; const int ho = wid / nlon_out;
int batch_b = blockIdx.z; const int wo = wid - (ho*nlon_out);
// load qy channels into shared memory for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) { #if 0
int channel_idx = channel_block_i*blockDim.x + threadIdx.x; // useless read, y is always zeroed before kernel is called
if(channel_idx >= num_channels) break; shy[chan] = y[batchId][chan][ho][wo];
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo]; #else
y[batch_b][channel_idx][ho][wo] = 0.0; shy[chan] = 0;
#endif
} }
__syncthreads(); float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
int psi_offset = psi_row_offset[ho]; const int rlen = rend-rbeg;
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads for(int off = 0; off < rlen; off++) {
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz]; const int64_t col = psi_col_idx[rbeg+off];
// compute input indices from psi datastructure const int hi = col / nlon_in;
int hi = nz_col_idx / nlon_in; const int wi = col - (hi*nlon_in);
// account for output shift and ensure positive index due to circular condition const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K) float qdotk = 0.0f;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) { for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip]; qdotk += qy[batchId][chan][ho][ wo]*
} kx[batchId][chan][hi][wip];
qdotk_max = std::max(qdotk_max, qdotk);
} }
qdotk = __warp_sum_cub(qdotk);
// collect thread-local qdotk max float qdotk_max_tmp;
atomicMax(&sh_qdotk_max[0], qdotk_max); float alpha;
__syncthreads(); float exp_save;
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0f; qdotk_max_tmp = max(qdotk_max, qdotk);
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) { alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
int idz = psi_block*blockDim.x + threadIdx.x; exp_save = expf(qdotk_max - qdotk_max_tmp);
float alpha_inz = 0.0;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz < psi_nnz_ho) {
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator with minus qdotk_max to avoid numerical overflow.
// Because qdotk_max is in both numerator and denominator (due to
// alpha_sum), it doesn't effect the solution other than removing overflow
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// thread-local sum alpha alpha_sum = alpha + alpha_sum*exp_save;
alpha_sum += alpha_inz;
// multiply alpha, vx, and quadrature weights for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) { shy[chan] = shy[chan]*exp_save + vx[batchId][chan][hi][wip]*alpha;
atomicAdd(&y[batch_b][channel_idx][ho][wo], alpha_inz * vx[batch_b][channel_idx][hi][wip]);
}
} }
qdotk_max = qdotk_max_tmp;
} }
// collect all alpha_sum across threads for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
atomicAdd(&sh_alpha_sum[0], alpha_sum); y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
__syncthreads();
// rebroadcast sum to all threads
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
y[batch_b][channel_idx][ho][wo] /= alpha_sum;
} }
return;
} }
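// The loop above evaluates softmax(q.k) and the weighted sum of v in a single pass
// over the sparse neighborhood by keeping a running maximum (qdotk_max), a running
// denominator (alpha_sum) and a running numerator per channel (shy), all rescaled by
// expf(old_max - new_max) whenever the maximum grows. The scalar sketch below spells
// out that online-softmax recurrence for one channel; OnlineSoftmaxState and
// accumulate_online_sketch are hypothetical names used only to illustrate the update
// rule, they are not part of this extension.
struct OnlineSoftmaxState {
    float m;    // running maximum of q.k (initialize to -FLT_MAX)
    float l;    // running weighted sum of exponentials (initialize to 0)
    float acc;  // running weighted sum of v for one channel (initialize to 0)
};

static inline void accumulate_online_sketch(OnlineSoftmaxState *s, float qdotk, float w, float v)
{
    const float m_new   = fmaxf(s->m, qdotk);
    const float alpha   = expf(qdotk - m_new) * w; // weight of the new neighbor, w plays the role of quad_weights[hi]
    const float rescale = expf(s->m - m_new);      // rescales every previously accumulated term
    s->l   = alpha + s->l * rescale;
    s->acc = s->acc * rescale + v * alpha;
    s->m   = m_new;
}
// The final per-channel output is s->acc / s->l, which matches y = shy / alpha_sum above.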
@@ -193,36 +195,73 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    auto k_channel_first = kx.strides()[1] == 1;
    auto v_channel_first = vx.strides()[1] == 1;
    auto q_channel_first = qy.strides()[1] == 1;

    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
    // Permute kx,vx,qy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        kxP = kx;
    }
    auto vxP = at::Tensor();
    if (!v_channel_first) {
        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        vxP = vx;
    }
    auto qyP = at::Tensor();
    if (!q_channel_first) {
        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        qyP = qy;
    }
    cudaDeviceSynchronize();
    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float)*uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS>
        <<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
                                               kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                               vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                               qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                               y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                               psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
                                               psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
                                               quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    // match output layout to input layout
    if (!q_channel_first) y = y.contiguous();

    C10_CUDA_KERNEL_LAUNCH_CHECK();
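// The launch above assigns one warp per output point: blockDim.x = WARP_SIZE lanes
// cooperate on the channel loops, blockDim.y = THREADS/WARP_SIZE warps per block each
// handle one (ho, wo), and the grid covers ceil(nlat_out*nlon_out / blockDim.y) blocks
// per batch element. The guarded definitions and the worked example below are only an
// illustrative assumption (THREADS = 256, 32-thread warps); the real WARP_SIZE, THREADS,
// DIV_UP and CHECK_CUDA definitions live elsewhere in this extension.
#ifndef WARP_SIZE
#define WARP_SIZE 32
#endif
#ifndef THREADS
#define THREADS 256
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + (b) - 1) / (b))   // ceiling division, as used above
#endif
// Example for nlat_out = 361, nlon_out = 720, batch_size = 4, num_channels = 256:
//   block         = dim3(32, 8)                  -> 8 warps, i.e. 8 output points per block
//   grid          = dim3(DIV_UP(361*720, 8), 4)  -> dim3(32490, 4)
//   shared memory = 256 * 8 * sizeof(float)      -> 8 KiB of shy rows per block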
@@ -33,10 +33,6 @@

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}