Unverified commit 26ce5cb5, authored by Thorsten Kurth, committed by GitHub

Merge pull request #77 from rietmann-nv/mr/bwd-channel-permute-experiments

Optimized CUDA kernels for S2 Attention (forward and backward)
parents 318fc76e 79fa6ad9
# coding=utf-8
-# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
...@@ -58,6 +58,7 @@ except ImportError as err:
    attention_cuda_extension = None
    _cuda_extension_available = False
+_perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
class TestNeighborhoodAttentionS2(unittest.TestCase):
    def setUp(self):
...@@ -65,7 +66,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            self.device = torch.device("cuda:0")
            torch.cuda.set_device(self.device.index)
            torch.cuda.manual_seed(333)
-           torch.manual_seed(333)
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(333)
...@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
-       ]
+       ],
+       skip_on_empty=True,
    )
    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
        """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
...@@ -157,7 +158,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
            [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
            [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
-       ]
+       ],
+       skip_on_empty=True,
    )
    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
        """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set to the whole sphere"""
...@@ -212,5 +214,97 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
        self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
+    @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
+    @parameterized.expand(
+        [
+            # self attention
+            #[1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
+            [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
+        ],
+        skip_on_empty=True,
+    )
+    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+        # extract some parameters
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+        # TODO: this test seems hardcoded for GPU. Is this necessary?
+        k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
+        k_gpu.requires_grad = False
+        v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
+        v_gpu.requires_grad = False
+        q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
+        q_gpu.requires_grad = False
+        # set up layers
+        time_layer_setup_start = torch.cuda.Event(enable_timing=True)
+        time_layer_setup_end = torch.cuda.Event(enable_timing=True)
+        time_layer_setup_start.record()
+        att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
+                                          in_shape=in_shape, out_shape=out_shape,
+                                          grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
+        time_layer_setup_end.record()
+        torch.cuda.synchronize()
+        # random weights
+        with torch.no_grad():
+            att_gpu.q_weights.normal_()
+            att_gpu.k_weights.normal_()
+            att_gpu.v_weights.normal_()
+            att_gpu.q_bias.normal_()
+            att_gpu.k_bias.normal_()
+            att_gpu.v_bias.normal_()
+        # time forward pass
+        for i in range(2):
+            # warmup
+            out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
+        time_forward_start = torch.cuda.Event(enable_timing=True)
+        time_forward_end = torch.cuda.Event(enable_timing=True)
+        time_forward_start.record()
+        out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
+        time_forward_end.record()
+        torch.cuda.synchronize()
+        elapsed_time = time_forward_start.elapsed_time(time_forward_end)
+        if verbose:
+            print(f"Forward execution time: {elapsed_time} ms")
+        self.assertTrue(elapsed_time < _perf_test_thresholds["fwd_ms"])
+        # sync weights:
+        with torch.no_grad():
+            att_gpu.q_weights.copy_(att_gpu.q_weights)
+            att_gpu.k_weights.copy_(att_gpu.k_weights)
+            att_gpu.v_weights.copy_(att_gpu.v_weights)
+            att_gpu.q_bias.copy_(att_gpu.q_bias)
+            att_gpu.k_bias.copy_(att_gpu.k_bias)
+            att_gpu.v_bias.copy_(att_gpu.v_bias)
+        q_gpu = q_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
+        q_gpu.requires_grad = True
+        k_gpu = k_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
+        k_gpu.requires_grad = True
+        v_gpu = v_gpu.detach().clone().to(self.device)#, memory_format=torch.channels_last)
+        v_gpu.requires_grad = True
+        out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
+        out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device=self.device)
+        time_backward_start = torch.cuda.Event(enable_timing=True)
+        time_backward_end = torch.cuda.Event(enable_timing=True)
+        for i in range(2):
+            # warmup
+            out_gpu.backward(out_grad, retain_graph=True)
+        time_backward_start.record()
+        out_gpu.backward(out_grad)
+        time_backward_end.record()
+        torch.cuda.synchronize()
+        elapsed_time = time_backward_start.elapsed_time(time_backward_end)
+        if verbose:
+            print(f"Backward execution time: {elapsed_time} ms")
+        self.assertTrue(elapsed_time < _perf_test_thresholds["bwd_ms"])
if __name__ == "__main__":
    unittest.main()
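A note on the timing methodology in `test_perf` above: GPU timings need explicit CUDA events plus a warm-up phase, because the first calls include kernel compilation and allocator overhead. The helper below is a minimal, illustrative sketch of that pattern only; it is not part of the test suite and the name `time_cuda_ms` is made up here.

```python
# Illustrative sketch only (not part of the test suite): the warm-up + CUDA-event
# pattern used by test_perf, factored into a reusable helper.
import torch

def time_cuda_ms(fn, warmup: int = 2) -> float:
    for _ in range(warmup):
        fn()                                    # warm up kernels and the allocator
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()                    # make sure both events have completed
    return start.elapsed_time(end)              # elapsed time in milliseconds
```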
...@@ -43,7 +43,10 @@ except ImportError as err:
    attention_cuda_extension = None
    _cuda_extension_available = False
+# s2 neighborhood attention forward pass
+# uses qdotk_max update trick to avoid two loops when computing the softmax
+# see e.g., https://arxiv.org/abs/1805.02867
+# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                         quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                         nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
...@@ -61,7 +64,7 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
    for wo in range(nlon_out):
        alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
-       qdotk_nz = torch.zeros((y.shape[0], zend-zstart,), dtype=y.dtype, device=y.device)
+       qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
        for idz in range(zstart, zend):
            nz_col_idx = col_idx[idz]
...@@ -75,24 +78,19 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
            # compute correlation & softmax numerator
            q_ho_wo = qy[:, :, ho, wo]
            k_hi_wip = kx[:, :, hi, wip]
-           qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wip, dim=1)
-       qdotk_max, _ = torch.max(qdotk_nz, dim=1)
-       for idz in range(zstart, zend):
-           nz_col_idx = col_idx[idz]
-           # compute input indices from psi datastructure
-           hi = nz_col_idx // nlon_in
-           # account for output shift and ensure positive index due to circular condition
-           wi = nz_col_idx % nlon_in
-           wip = (wi + wo) % nlon_in
-           alpha = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max)
-           # softmax denominator
-           alpha_sum[:] += alpha[:] * quad_weights[hi]
-           y[:,:,ho,wo] += alpha[:, None] * vx[:,:,hi,wip] * quad_weights[hi]
+           qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
+           # tmp max
+           qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
+           # alpha sum update
+           alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
+           alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
+           # update output
+           y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
+           # define new max
+           qdotk_max = qdotk_max_tmp
        y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
......
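For reference, the single-pass ("online") softmax update implemented in the new reference code above can be written as the following recurrence, where $s_j$ is the $j$-th query-key dot product in the neighborhood, $w_j$ the corresponding quadrature weight, and $v_j$ the value vector (notation introduced here for clarity; it does not appear in the code):

$$
m_j = \max(m_{j-1}, s_j), \qquad a_j = e^{\,s_j - m_j}\, w_j,
$$
$$
Z_j = a_j + Z_{j-1}\, e^{\,m_{j-1} - m_j}, \qquad
y_j = y_{j-1}\, e^{\,m_{j-1} - m_j} + a_j\, v_j .
$$

After the last neighbor $N$, the output is $y_N / Z_N$, which equals the quadrature-weighted softmax attention computed by the old two-pass scheme, but only one sweep over the neighborhood is needed and overflow is avoided because every exponent is non-positive relative to the running maximum $m_j$ (initialized to zero in the torch reference and to the lowest float in the CUDA kernel).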
...@@ -49,30 +49,3 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
                                                                        at::Tensor psi_col_idx,
                                                                        at::Tensor psi_row_off,
                                                                        int nlon_in, int nlat_out, int nlon_out);
-torch::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
-                                       at::Tensor vx,
-                                       at::Tensor qy,
-                                       at::Tensor dy,
-                                       at::Tensor quad_weights,
-                                       at::Tensor psi_col_idx,
-                                       at::Tensor psi_row_off,
-                                       int nlon_in, int nlat_out, int nlon_out);
-torch::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
-                                       at::Tensor vx,
-                                       at::Tensor qy,
-                                       at::Tensor dy,
-                                       at::Tensor quad_weights,
-                                       at::Tensor psi_col_idx,
-                                       at::Tensor psi_row_off,
-                                       int nlon_in, int nlat_out, int nlon_out);
-torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
-                                       at::Tensor vx,
-                                       at::Tensor qy,
-                                       at::Tensor dy,
-                                       at::Tensor quad_weights,
-                                       at::Tensor psi_col_idx,
-                                       at::Tensor psi_row_off,
-                                       int nlon_in, int nlat_out, int nlon_out);
...@@ -34,25 +34,65 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
-__device__ static float atomicMax(float* address, float val)
-{
-    int* address_as_i = (int*) address;
-    int old = *address_as_i, assumed;
-    do {
-        assumed = old;
-        old = ::atomicCAS(address_as_i, assumed,
-                          __float_as_int(::fmaxf(val, __int_as_float(assumed))));
-    } while (assumed != old);
-    return __int_as_float(old);
-}
+#define WARP_SIZE (32)
+#define FULL_MASK (0xFFFFFFFF)
+#define THREADS (64)
+#define DIV_UP(a,b) (((a)+((b)-1))/(b))
+#define NNZ_TRESH (32)
+#define CHECK_CUDA(call) { \
+    cudaError_t err = call; \
+    if( cudaSuccess != err) { \
+        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
+                __FILE__, __LINE__, cudaGetErrorString( err) ); \
+        exit(EXIT_FAILURE); \
+    }}
+#define CHECK_ERROR(errorMessage) { \
+    cudaError_t err = cudaGetLastError(); \
+    if( cudaSuccess != err) { \
+        fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
+                errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
+        exit(EXIT_FAILURE); \
+    }}
+static __device__ float __warp_sum(float val) {
+    #pragma unroll
+    for(int i = WARP_SIZE/2; i; i /= 2) {
+        val += __shfl_xor_sync(FULL_MASK, val, i);
+    }
+    return val;
+}
+// easier to understand version of manual shfl_xor_sync, performance appears similar
+static __device__ float __warp_sum_cub(float val) {
+    // use cub to reduce within a warp
+    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
+    // 1. Compute sum (initially only in lane 0)
+    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
+    // 2. Broadcast sum to all threads
+    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
+    return sum;
+}
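For readers less familiar with warp shuffles, the following sketch (Python/NumPy, illustrative only) simulates the XOR "butterfly" pattern that `__warp_sum` performs with `__shfl_xor_sync`: after log2(32) = 5 steps every lane of the 32-lane warp holds the full sum, which is also what the `cub::WarpReduce` variant achieves after the explicit broadcast.

```python
# Illustrative simulation only: the XOR-butterfly reduction used by __warp_sum.
import numpy as np

def warp_sum_butterfly(lane_vals):
    vals = np.asarray(lane_vals, dtype=np.float64).copy()
    assert vals.size == 32, "a warp has 32 lanes"
    offset = 16
    while offset:
        # each lane adds the value of the lane whose index differs in exactly one bit,
        # mirroring __shfl_xor_sync(FULL_MASK, val, offset)
        vals = vals + vals[np.arange(32) ^ offset]
        offset //= 2
    return vals  # every entry now equals the sum over all 32 lanes

assert warp_sum_butterfly(range(32))[0] == sum(range(32))  # 496
```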
-__global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
+// one warp per (ho,wo)
+template<int BDIM_X>
+__global__
+__launch_bounds__(BDIM_X)
+void s2_attention_kernel(int num_channels,
+                         int nlon_in,
+                         int nlat_out,
                          int nlon_out,
                          const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                          const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
...@@ -60,115 +100,77 @@ __global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
                          torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
                          const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                          const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
-                         const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
-{
-    // shared memory
-    extern __shared__ float sharedMem[];
-    float *sh_alpha_sum = (float *)&sharedMem;
-    float* sh_qdotk_max = (float*)&sharedMem[1];
-    float* sh_qy_ho_wo = (float *)&sharedMem[2];
-    if (threadIdx.x == 0) {
-        sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
-        sh_alpha_sum[0] = 0.0;
-    }
-    __syncthreads();
-    int ho = blockIdx.x;
-    int wo = blockIdx.y;
-    int batch_b = blockIdx.z;
-    // load qy channels into shared memory
-    for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
-        int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
-        if(channel_idx >= num_channels) break;
-        sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
-        y[batch_b][channel_idx][ho][wo] = 0.0;
-    }
-    __syncthreads();
-    int psi_offset = psi_row_offset[ho];
-    int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
-    float qdotk_max = std::numeric_limits<float>::lowest();
-    for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
-        int idz = psi_block*blockDim.x + threadIdx.x;
-        // skip if index >= length of psi_idx because last loop iteration will have extra threads
-        if(idz >= psi_nnz_ho) break;
-        int nz_col_idx = psi_col_idx[psi_offset+idz];
-        // compute input indices from psi datastructure
-        int hi = nz_col_idx / nlon_in;
-        // account for output shift and ensure positive index due to circular condition
-        // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
-        int wi = nz_col_idx % nlon_in;
-        int wip = (wi + wo) % nlon_in;
-        // correlation Q&K (dot-product Q.K)
-        float qdotk = 0.0;
-        for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
-            qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
-        }
-        qdotk_max = std::max(qdotk_max, qdotk);
-    }
-    // collect thread-local qdotk max
-    atomicMax(&sh_qdotk_max[0], qdotk_max);
-    __syncthreads();
-    // "broadcast" qdotk_max back into all thread-local registers
-    qdotk_max = sh_qdotk_max[0];
-    float alpha_sum = 0.0f;
-    for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
-        int idz = psi_block*blockDim.x + threadIdx.x;
-        float alpha_inz = 0.0;
-        // skip if index >= length of psi_idx because last loop iteration will have extra threads
-        if(idz < psi_nnz_ho) {
-            int nz_col_idx = psi_col_idx[psi_offset+idz];
-            // compute input indices from psi datastructure
-            int hi = nz_col_idx / nlon_in;
-            // account for output shift and ensure positive index due to circular condition
-            // int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
-            int wi = nz_col_idx % nlon_in;
-            int wip = (wi + wo) % nlon_in;
-            // softmax numerator with minus qdotk_max to avoid numerical overflow.
-            // Because qdotk_max is in both numerator and denominator (due to
-            // alpha_sum), it doesn't effect the solution other than removing overflow
-            // correlation Q&K (dot-product Q.K)
-            float qdotk = 0.0;
-            for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
-                qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
-            }
-            alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
-            // thread-local sum alpha
-            alpha_sum += alpha_inz;
-            // multiply alpha, vx, and quadrature weights
-            for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
-                atomicAdd(&y[batch_b][channel_idx][ho][wo], alpha_inz * vx[batch_b][channel_idx][hi][wip]);
-            }
-        }
-    }
-    // collect all alpha_sum across threads
-    atomicAdd(&sh_alpha_sum[0], alpha_sum);
-    __syncthreads();
-    // rebroadcast sum to all threads
-    alpha_sum = sh_alpha_sum[0];
-    // divide output by alpha_sum
-    for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
-        int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
-        if(channel_idx >= num_channels) break;
-        y[batch_b][channel_idx][ho][wo] /= alpha_sum;
-    }
-}
+                         const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
+    extern __shared__ float sh[];
+    float *shy = sh + threadIdx.y*num_channels;
+    const uint64_t batchId = blockIdx.y;
+    const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;
+    if (wid >= uint64_t(nlat_out)*nlon_in) {
+        return;
+    }
+    const int tidx = threadIdx.x;
+    const int ho = wid / nlon_out;
+    const int wo = wid - (ho*nlon_out);
+    for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
+#if 0
+        // useless read, y is always zeroed before kernel is called
+        shy[chan] = y[batchId][chan][ho][wo];
+#else
+        shy[chan] = 0;
+#endif
+    }
+    float alpha_sum = 0.0f;
+    float qdotk_max = -FLT_MAX;
+    const int64_t rbeg = psi_row_offset[ho];
+    const int64_t rend = psi_row_offset[ho+1];
+    const int rlen = rend-rbeg;
+    for(int off = 0; off < rlen; off++) {
+        const int64_t col = psi_col_idx[rbeg+off];
+        const int hi = col / nlon_in;
+        const int wi = col - (hi*nlon_in);
+        const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
+        float qdotk = 0.0f;
+        for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
+            qdotk += qy[batchId][chan][ho][wo]*
+                     kx[batchId][chan][hi][wip];
+        }
+        qdotk = __warp_sum_cub(qdotk);
+        float qdotk_max_tmp;
+        float alpha;
+        float exp_save;
+        qdotk_max_tmp = max(qdotk_max, qdotk);
+        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
+        exp_save = expf(qdotk_max - qdotk_max_tmp);
+        alpha_sum = alpha + alpha_sum*exp_save;
+        for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
+            shy[chan] = shy[chan]*exp_save + vx[batchId][chan][hi][wip]*alpha;
+        }
+        qdotk_max = qdotk_max_tmp;
+    }
+    for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
+        y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
+    }
+    return;
+}
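Both the reference implementation and the kernel above walk the neighborhood through the same CSR-like arrays: `psi_row_offset[ho]:psi_row_offset[ho+1]` indexes the nonzero columns for output latitude `ho`, each column encodes `(hi, wi)` as `hi*nlon_in + wi`, and the longitude is shifted by the output longitude `wo` modulo `nlon_in`. A small Python sketch of that iteration (illustrative only; the helper name is made up):

```python
# Illustrative only: how the CSR arrays describe the spherical neighborhood
# that both _neighborhood_attention_s2_fwd_torch and s2_attention_kernel traverse.
def iter_neighbors(psi_row_off, psi_col_idx, nlon_in, ho, wo):
    """Yield the (hi, wip) input coordinates attended to by output point (ho, wo)."""
    for idx in range(int(psi_row_off[ho]), int(psi_row_off[ho + 1])):
        col = int(psi_col_idx[idx])
        hi, wi = divmod(col, nlon_in)          # unpack latitude/longitude of the neighbor
        yield hi, (wi + wo) % nlon_in          # circular shift by the output longitude
```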
...@@ -193,36 +195,73 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    auto stream = at::cuda::getCurrentCUDAStream().stream();
-   // allocate output
-   torch::Tensor y = torch::zeros_like(qy);
    size_t uo_num_channels = kx.size(1);
-   size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
    const int batch_size = kx.size(0);
-   // cuda grid y,z size limitations
-   assert(nlon_out < 65535);
-   assert(batch_size < 65535);
-   // block-parallel over output points and batches
-   dim3 gridDim(nlat_out,nlon_out,batch_size);
-   // threads compute "blocks" of neighborhood and also "blocks" of channels
-   // note: blocksize of 512 runs into resource limits
-   dim3 blockDim(256,1,1);
-   s2_attention_kernel<<<gridDim, blockDim, sharedMemSize,stream>>>(
-       uo_num_channels, nlon_in, nlat_out, nlon_out,
-       kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-       qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+   auto k_channel_first = kx.strides()[1] == 1;
+   auto v_channel_first = vx.strides()[1] == 1;
+   auto q_channel_first = qy.strides()[1] == 1;
+   // transpose inputs so that channels are in the last dimension, allowing for
+   // coalesced memory access
+   nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
+   // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
+   auto kxP = at::Tensor();
+   if (!k_channel_first) {
+       // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+       kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+   } else {
+       kxP = kx;
+   }
+   auto vxP = at::Tensor();
+   if (!v_channel_first) {
+       // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+       vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+   } else {
+       vxP = vx;
+   }
+   auto qyP = at::Tensor();
+   if (!q_channel_first) {
+       // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
+       qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
+   } else {
+       qyP = qy;
+   }
+   cudaDeviceSynchronize();
+   nvtxRangePop();
+   torch::Tensor y = torch::empty_like(qy);
+   dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
+   dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
+   size_t shared_size = sizeof(float)*uo_num_channels * block.y;
+   cudaEvent_t start, stop;
+   float milliseconds = 0;
+   CHECK_CUDA(cudaEventCreate(&start));
+   CHECK_CUDA(cudaEventCreate(&stop));
+   CHECK_CUDA(cudaEventRecord(start, stream));
+   s2_attention_kernel<THREADS>
+       <<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
+           kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+           vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+           qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
            psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
            psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-           quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
-   );
+           quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
+   CHECK_CUDA(cudaEventRecord(stop, stream));
+   CHECK_CUDA(cudaEventSynchronize(stop));
+   CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
+   // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
+   CHECK_CUDA(cudaEventDestroy(start));
+   CHECK_CUDA(cudaEventDestroy(stop));
+   // match output layout to input layout
+   if (!q_channel_first) y = y.contiguous();
    C10_CUDA_KERNEL_LAUNCH_CHECK();
......
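The permutation applied in the launcher above deserves a note: `permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)` keeps the logical `[batch, channel, h, w]` shape while storing the data channels-last, which is what `strides()[1] == 1` detects and what gives the kernel coalesced reads over channels. A short PyTorch sketch (illustrative only):

```python
# Illustrative only: the double permute keeps the NCHW view but makes the channel
# dimension contiguous in memory (stride 1), i.e. a channels-last storage layout.
import torch

x = torch.randn(2, 8, 6, 12)                                   # [batch, channel, h, w]
xp = x.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)

print(xp.shape)                    # torch.Size([2, 8, 6, 12]) -- unchanged logical shape
print(x.stride(1), xp.stride(1))   # 72 1 -- channels become adjacent in memory
print(torch.equal(x, xp))          # True -- values are identical, only the layout differs
```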
...@@ -33,10 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
-   m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
-   m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
-   m.def("backward_dq", &s2_attention_bwd_dq_cuda,
-         "(Local) Attention gradient on S2 (gradient for q)");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
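Assuming the extension builds and imports (the tests above refer to it as `attention_cuda_extension`), the two remaining entry points could be called roughly as sketched below. The argument order is an assumption inferred from the C++ declarations and the torch reference signature, and the ordering of the returned gradient tuple is inferred from the `_dkvq` name; check the actual bindings before relying on either.

```python
# Sketch only -- argument order and return ordering are assumptions, see note above.
import attention_cuda_extension as ext  # import name used in the unit tests

# kx, vx, qy: [batch, channel, nlat, nlon] float32 CUDA tensors; dy: gradient w.r.t. y
# quad_weights: per-latitude quadrature weights; psi_col_idx / psi_row_off: CSR neighborhood
y = ext.forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off,
                nlon_in, nlat_out, nlon_out)
dk, dv, dq = ext.backward_dkvq(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off,
                               nlon_in, nlat_out, nlon_out)
```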