Unverified Commit 4aaff021 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Merge pull request #91 from NVIDIA/maurob/devel

Attention Backward improvement
parents ab44ba59 fa58767d
......@@ -61,6 +61,7 @@ def get_compile_args(module_name):
nvcc_extra_flags = []
if profile_mode:
nvcc_extra_flags.append("-lineinfo")
nvcc_extra_flags.append("-Xptxas=-v")
if debug_mode:
print(f"WARNING: Compiling {module_name} with debugging flags")
......@@ -102,6 +103,7 @@ def get_ext_modules():
CUDAExtension(
name="attention_cuda_extension",
sources=[
"torch_harmonics/csrc/attention/attention_utils.cu",
"torch_harmonics/csrc/attention/attention_fwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_bwd_cuda.cu",
"torch_harmonics/csrc/attention/attention_interface.cu",
......
......@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
],
skip_on_empty=True,
......@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[
# Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
# [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
],
......
......@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
B, _, H, W = grad_output.shape
grad_output = grad_output.reshape(B*nh, -1, H, W)
# save type and convert to float32
kw_dtype = kw.dtype
vw_dtype = vw.dtype
qw_dtype = qw.dtype
kw = kw.to(torch.float32).contiguous()
vw = vw.to(torch.float32).contiguous()
qw = qw.to(torch.float32).contiguous()
grad_output = grad_output.to(torch.float32).contiguous()
dkw,dvw,dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
quad_weights,
col_idx, row_off,
......@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
_, C, H, W = dqw.shape
dqw = dqw.reshape(B, -1, H, W)
# convert precision
dkw = dkw.to(dtype=kw_dtype)
dvw = dvw.to(dtype=vw_dtype)
dqw = dqw.to(dtype=qw_dtype)
# input grads
dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
......
......@@ -34,7 +34,11 @@
#include <cstdint>
#include <torch/torch.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
#define CHECK_CUDA_INPUT_TENSOR(x) \
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
......
......@@ -41,33 +41,18 @@
#include <cub/cub.cuh>
#include <limits>
#ifndef WARP_SIZE
#define WARP_SIZE (32)
#endif
#ifndef FULL_MASK
#define FULL_MASK (0xFFFFFFFF)
#endif
#ifndef THREADS
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#endif
#include "cudamacro.h"
#include "attention_utils.cuh"
#include <iostream>
#include <chrono>
#include <string>
#define THREADS (64)
#define MAX_LOCAL_ARR_LEN (16)
#if 0
class ScopeTimer
{
public:
......@@ -88,13 +73,6 @@ class ScopeTimer
std::chrono::high_resolution_clock::time_point start_;
};
static __device__ float __warp_sum(float val)
{
#pragma unroll
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
......@@ -216,6 +194,697 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
}
}
}
#endif
// BEGIN backward kernels and functions
// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
template<int BDIM_X,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM_X)
void s2_attn_bwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const FLOATV_T *__restrict__ dy,
const int32_t *__restrict__ row_idx,
const int64_t *__restrict__ row_off,
const int64_t *__restrict__ col_idx,
const float *__restrict__ quad_weights,
FLOATV_T *__restrict__ dkx,
FLOATV_T *__restrict__ dvx,
FLOATV_T *__restrict__ dqy) {
extern __shared__ __align__(sizeof(float4)) float shext[];
// for dqy
FLOATV_T *sh_alpha_k__ = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y * nchan*5;
FLOATV_T *sh_alpha_vw_ = sh_alpha_k__ + nchan;
FLOATV_T *sh_alpha_kvw = sh_alpha_vw_ + nchan;
FLOATV_T *sh_dy = sh_alpha_kvw + nchan;
FLOATV_T *sh_qy = sh_dy + nchan;
const int batch = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out)*nlon_in) {
return;
}
const int tidx = threadIdx.x;
// use permuted rows
const int h = wid / nlon_out;
const int wo = wid - (h*nlon_out);
const int ho = row_idx[h];
// offset input tensors
kx += int64_t(batch)*nlat_in*nlon_in*nchan;
vx += int64_t(batch)*nlat_in*nlon_in*nchan;
qy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan;
dy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan;
// offset output tensors
dkx += int64_t(batch)*nlat_in*nlon_in*nchan;
dvx += int64_t(batch)*nlat_in*nlon_in*nchan;
dqy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan;
// zero/init shared memory
for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
sh_alpha_k__[chan] = __vset<FLOATV_T>(0.0f);
sh_alpha_vw_[chan] = __vset<FLOATV_T>(0.0f);
sh_alpha_kvw[chan] = __vset<FLOATV_T>(0.0f);
sh_dy[chan] = dy[chan];
sh_qy[chan] = qy[chan];
}
#if __CUDA_ARCH__ < 900
// for architectures < 9.0, sh_dy and sh_qy will be read
// as individual floats at the end of the kernel, which
// breaks the assumption that each FLOATV_T location is
// written to and read by the same thread throughout the
// kernel, in the case FLOATV_T==float4
if constexpr(std::is_same<FLOATV_T, float4>::value) { __syncwarp(); }
#endif
// for dkx, dvx, dqy
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
// for dkx
float integral = 0.0f;
const int64_t rbeg = row_off[ho];
const int64_t rend = row_off[ho+1];
col_idx += rbeg;
const int rlen = rend - rbeg;
// accumulate alpha_sum, integral, and shared stats,
// along with a progressively computed qdotk_max.
for (int off = 0; off < rlen; off++) {
const int64_t col = col_idx[off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan]));
gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan]));
}
const float qdotk = __warp_sum(__vred(qdotk_v));
const float gdotv = __warp_sum(__vred(gdotv_v));
const float qdotk_max_tmp = max(qdotk_max, qdotk);
const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
const float max_correction = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha_sum * max_correction + alpha_inz;
integral = integral * max_correction + alpha_inz * gdotv;
const float ainz_gdotv = alpha_inz * gdotv;
for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
const FLOATV_T kxval = _kx[chan];
sh_alpha_k__[chan] = __vadd(__vscale(max_correction, sh_alpha_k__[chan]), __vscale(alpha_inz, kxval));
sh_alpha_vw_[chan] = __vadd(__vscale(max_correction, sh_alpha_vw_[chan]), __vset<FLOATV_T>(ainz_gdotv));
sh_alpha_kvw[chan] = __vadd(__vscale(max_correction, sh_alpha_kvw[chan]), __vscale(ainz_gdotv, kxval));
}
qdotk_max = qdotk_max_tmp;
}
const float alpha_sum_inv = 1.0f / alpha_sum;
integral *= alpha_sum_inv;
// Write dqy
for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
dqy[chan] = __vscale(alpha_sum_inv * alpha_sum_inv,
__vsub(__vscale(alpha_sum, sh_alpha_kvw[chan]),
__vmul(sh_alpha_vw_[chan], sh_alpha_k__[chan])));
}
// accumulate gradients for k and v
for (int off = 0; off < rlen; off++) {
const int64_t col = col_idx[off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);
for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan]));
gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan]));
}
const float qdotk = __warp_sum(__vred(qdotk_v));
const float gdotv = __warp_sum(__vred(gdotv_v));
const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
FLOATV_T *_dkx = dkx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T *_dvx = dvx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const float alpha_mul = alpha_inz * alpha_sum_inv;
const float scale_fact_qy = (gdotv - integral)*alpha_mul;
const float scale_fact_dy = alpha_mul;
// float4, 128-bit atomics are only supported by devices of compute
// capability 9.x+, so on older devices we resort to 32-bit atomics
#if __CUDA_ARCH__ < 900
// to use 32-bit operations on consecutve addresses
float *sh_qy_scl = reinterpret_cast<float *>(sh_qy);
float *sh_dy_scl = reinterpret_cast<float *>(sh_dy);
float *_dkx_scl = reinterpret_cast<float *>(_dkx);
float *_dvx_scl = reinterpret_cast<float *>(_dvx);
constexpr int VEC_SIZE = sizeof(FLOATV_T)/sizeof(float);
// 32-bit, consecutive atomics to glmem;
// strided atomics results in a severe slowdown
for (int chan = tidx; chan < nchan*VEC_SIZE; chan += WARP_SIZE) {
atomicAdd(_dkx_scl + chan, scale_fact_qy * sh_qy_scl[chan]);
atomicAdd(_dvx_scl + chan, scale_fact_dy * sh_dy_scl[chan]);
}
#else
// 128-bit, consecutive atomics to glmem
for (int chan = tidx; chan < nchan; chan += WARP_SIZE) {
atomicAdd(_dkx + chan, __vscale(scale_fact_qy, sh_qy[chan]));
atomicAdd(_dvx + chan, __vscale(scale_fact_dy, sh_dy[chan]));
}
#endif
}
return;
}
// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1)
template<int BDIM_X,
int BDIM_Y,
int NLOC,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void s2_attn_bwd_special_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const FLOATV_T *__restrict__ dy,
const int32_t *__restrict__ row_idx,
const int64_t *__restrict__ row_off,
const int64_t *__restrict__ col_idx,
const float *__restrict__ quad_weights,
FLOATV_T *__restrict__ dkx,
FLOATV_T *__restrict__ dvx,
FLOATV_T *__restrict__ dqy) {
static_assert(0 == (BDIM_X & (BDIM_X-1)));
static_assert(0 == (BDIM_Y & (BDIM_Y-1)));
static_assert((BDIM_X == 32 && BDIM_Y > 1) ||
(BDIM_X > 32 && BDIM_Y == 1)) ;
constexpr int NLOC_M1 = NLOC-1;
const int tidx = threadIdx.x;
const int batch = blockIdx.y;
const uint64_t ctaid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (ctaid >= uint64_t(nlat_out)*nlon_in) {
return;
}
extern __shared__ __align__(sizeof(float4)) float shext[];
FLOATV_T *sh_dy = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan*2 + tidx;
FLOATV_T *sh_qy = sh_dy + nchan;
// for dqy
FLOATV_T loc_k__[NLOC];
FLOATV_T loc_vw_[NLOC];
FLOATV_T loc_kvw[NLOC];
// use permuted rows
const int h = ctaid / nlon_out;
const int wo = ctaid - (h*nlon_out);
const int ho = row_idx[h];
// offset input tensors
kx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
vx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
qy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
dy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
// offset output tensors
dkx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
dvx += int64_t(batch)*nlat_in*nlon_in*nchan + tidx;
dqy += int64_t(batch)*nlat_out*nlon_out*nchan + int64_t(ho)*nlon_out*nchan + int64_t(wo)*nchan + tidx;
#pragma unroll
for(int i = 0; i < NLOC; i++) {
loc_k__[i] = __vset<FLOATV_T>(0.0f);
loc_vw_[i] = __vset<FLOATV_T>(0.0f);
loc_kvw[i] = __vset<FLOATV_T>(0.0f);
}
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
sh_dy[i*BDIM_X] = dy[i*BDIM_X];
sh_qy[i*BDIM_X] = qy[i*BDIM_X];
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
sh_dy[NLOC_M1*BDIM_X] = dy[NLOC_M1*BDIM_X];
sh_qy[NLOC_M1*BDIM_X] = qy[NLOC_M1*BDIM_X];
}
#if __CUDA_ARCH__ < 900
// for architectures < 9.0, sh_dy and sh_qy will be read
// as individual floats at the end of the kernel, which
// breaks the assumption that each FLOATV_T location is
// written to and read by the same thread throughout the
// kernel, in the case FLOATV_T==float4
if constexpr(std::is_same<FLOATV_T, float4>::value) {
if constexpr(BDIM_X == 32) { __syncwarp(); }
else { __syncthreads(); }
}
#endif
// for dkx, dvx, dqy
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
// for dkx
float integral = 0.0f;
const int64_t rbeg = row_off[ho];
const int64_t rend = row_off[ho+1];
col_idx += rbeg;
const int rlen = rend - rbeg;
// accumulate alpha_sum, integral, and shared stats,
// along with a progressively computed qdotk_max.
for (int off = 0; off < rlen; off++) {
const int64_t col = col_idx[off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[i*BDIM_X], _kx[i*BDIM_X]));
gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[i*BDIM_X], _vx[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[NLOC_M1*BDIM_X], _kx[NLOC_M1*BDIM_X]));
gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[NLOC_M1*BDIM_X], _vx[NLOC_M1*BDIM_X]));
}
float qdotk = __vred(qdotk_v);
float gdotv = __vred(gdotv_v);
if constexpr(BDIM_X == 32) {
qdotk = __warp_sum(qdotk);
gdotv = __warp_sum(gdotv);
} else {
qdotk = __block_sum<BDIM_X>(qdotk);
gdotv = __block_sum<BDIM_X>(gdotv);
}
const float qdotk_max_tmp = max(qdotk_max, qdotk);
const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
const float max_correction = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha_sum * max_correction + alpha_inz;
integral = integral * max_correction + alpha_inz * gdotv;
const float ainz_gdotv = alpha_inz * gdotv;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
const FLOATV_T kxval = _kx[i*BDIM_X];
loc_k__[i] = __vadd(__vscale(max_correction, loc_k__[i]), __vscale(alpha_inz, kxval));
loc_vw_[i] = __vadd(__vscale(max_correction, loc_vw_[i]), __vset<FLOATV_T>(ainz_gdotv));
loc_kvw[i] = __vadd(__vscale(max_correction, loc_kvw[i]), __vscale(ainz_gdotv, kxval));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
const FLOATV_T kxval = _kx[NLOC_M1*BDIM_X];
loc_k__[NLOC_M1] = __vadd(__vscale(max_correction, loc_k__[NLOC_M1]), __vscale(alpha_inz, kxval));
loc_vw_[NLOC_M1] = __vadd(__vscale(max_correction, loc_vw_[NLOC_M1]), __vset<FLOATV_T>(ainz_gdotv));
loc_kvw[NLOC_M1] = __vadd(__vscale(max_correction, loc_kvw[NLOC_M1]), __vscale(ainz_gdotv, kxval));
}
qdotk_max = qdotk_max_tmp;
}
const float alpha_sum_inv = 1.0f / alpha_sum;
integral *= alpha_sum_inv;
// Write dqy
const float alpha_sum_inv_sq = alpha_sum_inv*alpha_sum_inv;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
dqy[i*BDIM_X] = __vscale(alpha_sum_inv_sq,
__vsub(__vscale(alpha_sum, loc_kvw[i]),
__vmul(loc_vw_[i], loc_k__[i])));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
dqy[NLOC_M1*BDIM_X] = __vscale(alpha_sum_inv_sq,
__vsub(__vscale(alpha_sum, loc_kvw[NLOC_M1]),
__vmul(loc_vw_[NLOC_M1], loc_k__[NLOC_M1])));
}
// accumulate gradients for k and v
for (int off = 0; off < rlen; off++) {
const int64_t col = col_idx[off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T qdotk_v = __vset<FLOATV_T>(0.0f);
FLOATV_T gdotv_v = __vset<FLOATV_T>(0.0f);
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[i*BDIM_X], _kx[i*BDIM_X]));
gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[i*BDIM_X], _vx[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[NLOC_M1*BDIM_X], _kx[NLOC_M1*BDIM_X]));
gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[NLOC_M1*BDIM_X], _vx[NLOC_M1*BDIM_X]));
}
float qdotk = __vred(qdotk_v);
float gdotv = __vred(gdotv_v);
if constexpr(BDIM_X == 32) {
qdotk = __warp_sum(qdotk);
gdotv = __warp_sum(gdotv);
} else {
qdotk = __block_sum<BDIM_X>(qdotk);
gdotv = __block_sum<BDIM_X>(gdotv);
}
const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
FLOATV_T *_dkx = dkx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
FLOATV_T *_dvx = dvx + int64_t(hi)*nlon_in*nchan + int64_t(wip)*nchan;
const float alpha_mul = alpha_inz * alpha_sum_inv;
const float scale_fact_qy = (gdotv - integral)*alpha_mul;
const float scale_fact_dy = alpha_mul;
// float4, 128-bit atomics are only supported by devices of compute
// capability 9.x+, so on older devices we resort to 32-bit atomics
#if __CUDA_ARCH__ < 900
// making the loop count known at compile time doesn't seem
// to make any difference here so let's keep this (much)
// simpler version
float *sh_qy_scl = reinterpret_cast<float *>(sh_qy - tidx);
float *sh_dy_scl = reinterpret_cast<float *>(sh_dy - tidx);
float *_dkx_scl = reinterpret_cast<float *>(_dkx - tidx);
float *_dvx_scl = reinterpret_cast<float *>(_dvx - tidx);
constexpr int VEC_SIZE = sizeof(FLOATV_T)/sizeof(float);
// 32-bit, consecutive atomics to glmem
// strided atomics results in a severe slowdown
for (int chan = tidx; chan < nchan*VEC_SIZE; chan += BDIM_X) {
atomicAdd(_dkx_scl + chan, scale_fact_qy * sh_qy_scl[chan]);
atomicAdd(_dvx_scl + chan, scale_fact_dy * sh_dy_scl[chan]);
}
#else
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
atomicAdd(_dkx + i*BDIM_X, __vscale(scale_fact_qy, sh_qy[i*BDIM_X]));
atomicAdd(_dvx + i*BDIM_X, __vscale(scale_fact_dy, sh_dy[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
atomicAdd(_dkx + NLOC_M1*BDIM_X, __vscale(scale_fact_qy, sh_qy[NLOC_M1*BDIM_X]));
atomicAdd(_dvx + NLOC_M1*BDIM_X, __vscale(scale_fact_dy, sh_dy[NLOC_M1*BDIM_X]));
}
#endif
}
return;
}
template<typename FLOATV_T>
void launch_gen_attn_bwd(int batch_size,
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *_kxp,
FLOATV_T *_vxp,
FLOATV_T *_qyp,
FLOATV_T *_dyp,
int32_t *_row_idx,
int64_t *_row_off,
int64_t *_col_idx,
float *_quad_weights,
FLOATV_T *_dkxp,
FLOATV_T *_dvxp,
FLOATV_T *_dqyp,
cudaStream_t stream) {
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shsize = sizeof(FLOATV_T)*nchans*5 * block.y; // 5 arrays per warp
s2_attn_bwd_generic_vec_k<THREADS>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
_quad_weights, _dkxp, _dvxp, _dqyp);
CHECK_ERROR("s2_attn_bwd_generic_vec_k");
return;
}
template<int BDIM_X,
int BDIM_Y,
int CUR_LOC_SIZE,
int MAX_LOC_SIZE, // max size of FLOATV_T[] local array
typename FLOATV_T>
void launch_spc_attn_bwd(int batch_size,
int nloc, // "BDIM_X*nloc" >= nchans
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *_kxp,
FLOATV_T *_vxp,
FLOATV_T *_qyp,
FLOATV_T *_dyp,
int32_t *_row_idx,
int64_t *_row_off,
int64_t *_col_idx,
float *_quad_weights,
FLOATV_T *_dkxp,
FLOATV_T *_dvxp,
FLOATV_T *_dqyp,
cudaStream_t stream) {
if (CUR_LOC_SIZE == nloc) {
dim3 block(BDIM_X, BDIM_Y);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shsize = sizeof(FLOATV_T)*nchans*2 * block.y; // 2 arrays per cta, block.y > 1 iif block.x==32
s2_attn_bwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
_quad_weights, _dkxp, _dvxp, _dqyp);
CHECK_ERROR("s2_attn_bwd_special_vec_k");
return;
}
if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
launch_spc_attn_bwd<BDIM_X,
BDIM_Y,
CUR_LOC_SIZE+1,
MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights,
_dkxp, _dvxp, _dqyp, stream);
}
return;
}
static void s2_attn_bwd_dispatch(int batch_size,
int nchans,
int nlon_in,
int nlat_out,
int nlon_out,
at::Tensor kxP,
at::Tensor vxP,
at::Tensor qyP,
at::Tensor dyP,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
at::Tensor dkxP,
at::Tensor dvxP,
at::Tensor dqyP) {
static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// sort row indices (ho-s) in descending order
// based on (row_off[ho+1]-row_off[ho])
at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
const int nlat_in = kxP.size(1);
// smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans
int bdimx;
bdimx = DIV_UP(nchans, MAX_LOCAL_ARR_LEN);
bdimx = max(bdimx, WARP_SIZE);
bdimx = next_pow2(bdimx);
float *_kxp = reinterpret_cast<float *>(kxP.data_ptr());
float *_vxp = reinterpret_cast<float *>(vxP.data_ptr());
float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
float *_dyp = reinterpret_cast<float *>(dyP.data_ptr());
float *_dkxp = reinterpret_cast<float *>(dkxP.data_ptr());
float *_dvxp = reinterpret_cast<float *>(dvxP.data_ptr());
float *_dqyp = reinterpret_cast<float *>(dqyP.data_ptr());
int32_t *_row_idx = reinterpret_cast<int32_t *>(row_idx.data_ptr());
int64_t *_row_off = reinterpret_cast<int64_t *>(row_off.data_ptr());
int64_t *_col_idx = reinterpret_cast<int64_t *>(col_idx.data_ptr());
float *_quad_weights = reinterpret_cast<float *>(quad_weights.data_ptr());
constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);
if (!is_aligned<sizeof(float4)>(_kxp) ||
!is_aligned<sizeof(float4)>(_vxp) ||
!is_aligned<sizeof(float4)>(_qyp) ||
!is_aligned<sizeof(float4)>(_dyp) ||
!is_aligned<sizeof(float4)>(_dkxp) ||
!is_aligned<sizeof(float4)>(_dvxp) ||
!is_aligned<sizeof(float4)>(_dqyp) ||
(nchans % VEC_SIZE) != 0) {
const int nloc = DIV_UP(nchans, bdimx);
// to avoid the compilation of unused template instances;
// we use a block size BDIM_X that is the smallest power of 2
// such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans, so
// BDIM_X > 32 are used only for:
//
// (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchans <= BDIM_X*MAX_LOCAL_ARR_LEN
constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1;
// use 2D blocks only if 32 threads are enough; w.r.t fowrard,
// we use the special kernel only up to BDIM_X=512 as with 1024
// each thread cannot use more than 64 registers, resulting in
// large amounts of registers spills
switch(bdimx) {
case 32: launch_spc_attn_bwd< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
case 64: launch_spc_attn_bwd< 64, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
case 128: launch_spc_attn_bwd<128, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
case 256: launch_spc_attn_bwd<256, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
case 512: launch_spc_attn_bwd<512, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
default: launch_gen_attn_bwd (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break;
}
} else {
float4 *_kxp4 = reinterpret_cast<float4 *>(kxP.data_ptr());
float4 *_vxp4 = reinterpret_cast<float4 *>(vxP.data_ptr());
float4 *_qyp4 = reinterpret_cast<float4 *>(qyP.data_ptr());
float4 *_dyp4 = reinterpret_cast<float4 *>(dyP.data_ptr());
float4 *_dkxp4 = reinterpret_cast<float4 *>(dkxP.data_ptr());
float4 *_dvxp4 = reinterpret_cast<float4 *>(dvxP.data_ptr());
float4 *_dqyp4 = reinterpret_cast<float4 *>(dqyP.data_ptr());
nchans /= VEC_SIZE;
const int nloc = DIV_UP(nchans, bdimx);
constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1;
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_bwd< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
case 64: launch_spc_attn_bwd< 64, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
case 128: launch_spc_attn_bwd<128, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
case 256: launch_spc_attn_bwd<256, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
case 512: launch_spc_attn_bwd<512, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
default: launch_gen_attn_bwd (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break;
}
}
return;
}
// END backward kernels and functions
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
at::Tensor dy, at::Tensor quad_weights,
......@@ -223,15 +892,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
int nlon_in, int nlat_out, int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_INPUT_TENSOR(kx);
CHECK_CUDA_INPUT_TENSOR(vx);
CHECK_CUDA_INPUT_TENSOR(qy);
CHECK_CUDA_INPUT_TENSOR(dy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
// extract dtype
auto kx_type = kx.dtype();
......@@ -239,84 +909,52 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
auto qy_type = qy.dtype();
auto dy_type = dy.dtype();
// exract memory format
auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
// convert to channels-last
auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// create output arrays
auto dydk = torch::zeros_like(qyP);
auto dydv = torch::zeros_like(qyP);
auto dydq = torch::zeros_like(qyP);
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
// convert back to original dtype
dydk = dydk.to(kx_type);
dydv = dydv.to(vx_type);
dydq = dydq.to(qy_type);
// permute back to original layout
if (!kx_is_channels_last) {
dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
} else {
dydk = dydk.to(kx_type);
}
if (!vx_is_channels_last) {
dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
} else {
dydv = dydv.to(vx_type);
}
if (!qy_is_channels_last) {
dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
} else {
dydq = dydq.to(qy_type);
}
return std::make_tuple(dydk, dydv, dydq);
torch::Tensor kxP = kx.to(torch::kFloat32);
torch::Tensor vxP = vx.to(torch::kFloat32);
torch::Tensor qyP = qy.to(torch::kFloat32);
torch::Tensor dyP = dy.to(torch::kFloat32);
// exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
// the former fails for num_channels == 1
bool kx_is_channels_last = kxP.strides()[1] == 1;
bool vx_is_channels_last = vxP.strides()[1] == 1;
bool qy_is_channels_last = qyP.strides()[1] == 1;
bool dy_is_channels_last = dyP.strides()[1] == 1;
// transpose if required
if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }
if (!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); }
torch::Tensor dkxP = torch::zeros_like(kxP);
torch::Tensor dvxP = torch::zeros_like(vxP);
torch::Tensor dqyP = torch::zeros_like(qyP);
s2_attn_bwd_dispatch(batch_size,
uo_num_channels,
nlon_in,
nlat_out,
nlon_out,
kxP, vxP, qyP, dyP,
psi_row_off,
psi_col_idx,
quad_weights,
dkxP, dvxP, dqyP);
torch::Tensor dkx = dkxP;
torch::Tensor dvx = dvxP;
torch::Tensor dqy = dqyP;
if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); }
if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); }
if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); }
// convert precision back to starting
dkx = dkx.to(kx_type);
dvx = dvx.to(vx_type);
dqy = dqy.to(qy_type);
return std::make_tuple(dkx, dvx, dqy);
// #endif
}
......@@ -39,147 +39,20 @@
#include <cub/cub.cuh>
#include <limits>
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#include "cudamacro.h"
#include "attention_utils.cuh"
#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)
#define THREADS (64)
#define MAX_LOCAL_ARR_LEN (16)
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
// BEGIN - forward kernels and functions
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
// called with (blockDim.x=32 and blockDim.y>1, BDIM_X=blockDim.x*blockDim.y)
template<int BDIM_X,
int BDIM_Y=1,
int BDIM_Z=1,
typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {
const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;
val = __warp_sum(val);
if constexpr(NWARP > 1) {
int tid = threadIdx.x;
if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }
const int lid = tid%WARP_SIZE;
const int wid = tid/WARP_SIZE;
__shared__ VAL_T sh[NWARP];
if (lid == 0) {
sh[wid] = val;
}
__syncthreads();
if (wid == 0) {
val = (lid < NWARP) ? sh[lid] : 0;
val = __warp_sum(val);
__syncwarp();
if (!lid) {
sh[0] = val;
}
}
__syncthreads();
val = sh[0];
__syncthreads();
}
return val;
}
template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {}
template<>
__device__ float __forceinline__ __vset<float>(float x) {
return x;
}
__device__ float __forceinline__ __vmul(float a, float b) {
return a*b;
}
__device__ float __forceinline__ __vadd(float a, float b) {
return a+b;
}
__device__ float __forceinline__ __vred(float a) {
return a;
}
__device__ float __forceinline__ __vscale(float s, float v) {
return v*s;
}
__device__ float __forceinline__ __vdiv(float s, float v) {
return v/s;
}
template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
return make_float4(x, x, x, x);
}
__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}
__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}
__device__ float __forceinline__ __vred(float4 a) {
return a.x + a.y + a.z + a.w;
}
__device__ float4 __forceinline__ __vscale(float s, float4 v) {
return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}
__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);;
}
// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
template<int BDIM,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM)
__launch_bounds__(BDIM_X)
void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
......@@ -188,10 +61,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const torch::PackedTensorAccessor32< int, 1, torch::RestrictPtrTraits> row_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
const int32_t *__restrict__ row_idx,
const int64_t *__restrict__ row_off,
const int64_t *__restrict__ col_idx,
const float *__restrict__ quad_weights,
FLOATV_T *__restrict__ y) {
extern __shared__ __align__(sizeof(float4)) float shext[];
......@@ -225,11 +98,13 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
const int64_t rbeg = row_off[ho];
const int64_t rend = row_off[ho+1];
col_idx += rbeg;
const int rlen = rend-rbeg;
for(int off = 0; off < rlen; off++) {
const int64_t col = col_idx[rbeg+off];
const int64_t col = col_idx[off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
......@@ -273,39 +148,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
return;
}
template<typename FLOATV_T>
void launch_gen_attn_kernel(int batch_size,
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
at::Tensor row_idx,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y;
auto _row_idx = row_idx.packed_accessor32< int, 1, torch::RestrictPtrTraits>();
auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _quad_weights = quad_weights.packed_accessor32< float, 1, torch::RestrictPtrTraits>();
s2_attn_fwd_generic_vec_k<THREADS>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
return;
}
// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1)
template<int BDIM_X,
int BDIM_Y,
......@@ -321,10 +163,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const torch::PackedTensorAccessor32< int, 1, torch::RestrictPtrTraits> row_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
const int32_t *__restrict__ row_idx,
const int64_t *__restrict__ row_off,
const int64_t *__restrict__ col_idx,
const float *__restrict__ quad_weights,
FLOATV_T *__restrict__ y) {
static_assert(0 == (BDIM_X & (BDIM_X-1)));
......@@ -375,11 +217,13 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const int64_t rbeg = row_off[ho];
const int64_t rend = row_off[ho+1];
col_idx += rbeg;
const int rlen = rend-rbeg;
for(int off = 0; off < rlen; off++) {
const int64_t col = col_idx[rbeg+off];
const int64_t col = col_idx[off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
......@@ -442,139 +286,84 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
return;
}
template<typename FLOATV_T>
void launch_gen_attn_fwd(int batch_size,
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
int32_t *_row_idx,
int64_t *_row_off,
int64_t *_col_idx,
float *_quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y;
s2_attn_fwd_generic_vec_k<THREADS>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
CHECK_ERROR("s2_attn_fwd_generic_vec_k");
return;
}
template<int BDIM_X,
int BDIM_Y,
int CUR_LOC_SIZE,
int MAX_LOC_SIZE, // max size of FLOATV_T[] local array
typename FLOATV_T>
void launch_spc_attn_kernel(int batch_size,
int nloc, // "BDIM_X*nloc" >= nchans
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
at::Tensor row_idx,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
void launch_spc_attn_fwd(int batch_size,
int nloc, // "BDIM_X*nloc" >= nchans
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
int32_t *_row_idx,
int64_t *_row_off,
int64_t *_col_idx,
float *_quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
if (CUR_LOC_SIZE == nloc) {
auto _row_idx = row_idx.packed_accessor32< int, 1, torch::RestrictPtrTraits>();
auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _quad_weights = quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>();
dim3 block(BDIM_X, BDIM_Y);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
//printf("block: (%d, %d)\n", block.x, block.y);
//printf("grid: (%d, %d)\n", grid.x, grid.y);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y; // block.y > 1 iif block.x==32
s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
CHECK_ERROR("s2_attn_fwd_special_vec_k");
return;
}
if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
launch_spc_attn_kernel<BDIM_X,
BDIM_Y,
CUR_LOC_SIZE+1,
MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp,
stream);
launch_spc_attn_fwd<BDIM_X,
BDIM_Y,
CUR_LOC_SIZE+1,
MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp,
stream);
}
return;
}
__global__ void set_rlen_rids_k(const int n,
const int64_t *__restrict__ offs,
int *__restrict__ rids,
int *__restrict__ rlen) {
const int nth = gridDim.x*blockDim.x;
const int tid = blockIdx.x*blockDim.x + threadIdx.x;
for(int i = tid; i < n; i += nth) {
rids[i] = i;
rlen[i] = offs[i+1]-offs[i];
}
return;
}
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());
auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());
torch::Tensor rids_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_d = torch::empty({nlat_out}, options);
int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());
const int grid = DIV_UP(nlat_out, THREADS);
const int block = THREADS;
set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
_row_off_d,
_rids_d,
_rlen_d);
torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);
int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());
size_t temp_storage_bytes = 0;
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
return rids_sort_d;
}
template<unsigned int ALIGN>
int is_aligned(const void *ptr) {
static_assert(0 == (ALIGN & (ALIGN-1)));
return (0 == (uintptr_t(ptr) & (ALIGN-1)));
}
static unsigned int next_pow2(unsigned int x) {
x -= 1;
#pragma unroll
for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
x |= x >> i;
}
return x+1;
}
static void s2_attention_dipatch(int batch_size,
static void s2_attn_fwd_dispatch(int batch_size,
int nchans,
int nlon_in,
int nlat_out,
......@@ -585,11 +374,13 @@ static void s2_attention_dipatch(int batch_size,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
at::Tensor yP,
cudaStream_t stream) {
at::Tensor yP) {
static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// sort row indices (ho-s) in descending order
// based on (row_off[ho+1]-row_off[ho])
at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
......@@ -607,6 +398,11 @@ static void s2_attention_dipatch(int batch_size,
float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
float *_yp = reinterpret_cast<float *>(yP.data_ptr());
int32_t *_row_idx = reinterpret_cast<int32_t *>(row_idx.data_ptr());
int64_t *_row_off = reinterpret_cast<int64_t *>(row_off.data_ptr());
int64_t *_col_idx = reinterpret_cast<int64_t *>(col_idx.data_ptr());
float *_quad_weights = reinterpret_cast<float *>(quad_weights.data_ptr());
constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);
if (!is_aligned<sizeof(float4)>(_kxp) ||
......@@ -616,16 +412,24 @@ static void s2_attention_dipatch(int batch_size,
(nchans % VEC_SIZE) != 0) {
const int nloc = DIV_UP(nchans, bdimx);
// to avoid the compilation of unused template instances;
// we use a block size BDIM_X that is the smallest power of 2
// such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans, so
// BDIM_X > 32 are used only for:
//
// (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchans <= BDIM_X*MAX_LOCAL_ARR_LEN
constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1;
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
default: launch_gen_attn_kernel (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
case 32: launch_spc_attn_fwd< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
case 64: launch_spc_attn_fwd< 64, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
case 128: launch_spc_attn_fwd< 128, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
case 256: launch_spc_attn_fwd< 256, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
case 512: launch_spc_attn_fwd< 512, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
case 1024: launch_spc_attn_fwd<1024, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
default: launch_gen_attn_fwd (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
}
} else {
......@@ -638,233 +442,26 @@ static void s2_attention_dipatch(int batch_size,
nchans /= VEC_SIZE;
const int nloc = DIV_UP(nchans, bdimx);
static constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1;
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
default: launch_gen_attn_kernel (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
case 32: launch_spc_attn_fwd< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
case 64: launch_spc_attn_fwd< 64, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
case 128: launch_spc_attn_fwd< 128, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
case 256: launch_spc_attn_fwd< 256, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
case 512: launch_spc_attn_fwd< 512, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
case 1024: launch_spc_attn_fwd<1024, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
default: launch_gen_attn_fwd (batch_size, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
}
}
return;
}
// END - forward kernels and functions
// BEGIN - tensor permutation kernels and functions
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : 0.f;
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
__global__ void empty_k() {}
static int getPtxver() {
cudaFuncAttributes attrs;
CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
return attrs.ptxVersion*10;
}
static at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
}
return dst;
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : 0.f;
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];;
}
}
}
}
return;
}
static at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
}
return dst;
}
// END - tensor permutation kernels and functions
// END - forward kernels and functions
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
......@@ -875,36 +472,37 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
int nlon_in,
int nlat_out,
int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_INPUT_TENSOR(kx);
CHECK_CUDA_INPUT_TENSOR(vx);
CHECK_CUDA_INPUT_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
torch::Tensor kxP = kx;
torch::Tensor vxP = vx;
torch::Tensor qyP = qy;
// extract dtype
auto qy_type = qy.dtype();
torch::Tensor kxP = kx.to(torch::kFloat32);
torch::Tensor vxP = vx.to(torch::kFloat32);
torch::Tensor qyP = qy.to(torch::kFloat32);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
// these are much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
// the former fails for num_channels == 1
bool kx_is_channels_last = kxP.strides()[1] == 1;
bool vx_is_channels_last = vxP.strides()[1] == 1;
bool qy_is_channels_last = qyP.strides()[1] == 1;
if (!k_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
if (!v_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
if (!q_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }
torch::Tensor yP = torch::empty_like(qyP);
s2_attention_dipatch(batch_size,
s2_attn_fwd_dispatch(batch_size,
uo_num_channels,
nlon_in,
nlat_out,
......@@ -913,11 +511,13 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
psi_row_off,
psi_col_idx,
quad_weights,
yP, // out tensor
stream);
yP);
torch::Tensor y = yP;
if (!q_channel_first) { y = permute_4D_floatT_to0312(yP, stream); }
if (!qy_is_channels_last) { y = permute_4D_to0312(y); }
// convert precision back to starting
y = y.to(qy_type);
C10_CUDA_KERNEL_LAUNCH_CHECK();
......
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <limits>
#include "cudamacro.h"
#include "attention_utils.cuh"
#define THREADS (64)
#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)
// BEGIN - CSR rows sorting kernels and functions
__global__ void set_rlen_rids_k(const int n,
const int64_t *__restrict__ offs,
int *__restrict__ rids,
int *__restrict__ rlen) {
const int nth = gridDim.x*blockDim.x;
const int tid = blockIdx.x*blockDim.x + threadIdx.x;
for(int i = tid; i < n; i += nth) {
rids[i] = i;
rlen[i] = offs[i+1]-offs[i];
}
return;
}
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());
auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());
torch::Tensor rids_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_d = torch::empty({nlat_out}, options);
int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());
const int grid = DIV_UP(nlat_out, THREADS);
const int block = THREADS;
set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
_row_off_d,
_rids_d,
_rlen_d);
torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);
int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());
size_t temp_storage_bytes = 0;
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());
CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
_rlen_d, _rlen_sort_d,
_rids_d, _rids_sort_d,
nlat_out, 0, sizeof(*_rlen_d)*8, stream));
return rids_sort_d;
}
// END - CSR rows sorting kernels and functions
// BEGIN - 4D tensor permutation kernels and functions
__global__ void empty_k() {}
static int getPtxver() {
cudaFuncAttributes attrs;
CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
return attrs.ptxVersion*10;
}
at::Tensor permute_4D_to0231(at::Tensor src) {
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0231_k_tile_generic");
} else {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0231_k_tile_sm100");
}
return dst;
}
at::Tensor permute_4D_to0312(at::Tensor src) {
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0312_k_tile_generic");
} else {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
}));
CHECK_ERROR("permute_to0312_k_tile_sm100");
}
return dst;
}
// END - tensor permutation kernels and functions
// BEGIN - general host-side functions
unsigned int next_pow2(unsigned int x) {
x -= 1;
#pragma unroll
for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
x |= x >> i;
}
return x+1;
}
// END - general host-side functions
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <ATen/ATen.h>
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
// CSR rows sorting kernels and functions
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);
// 4D tensor permutation kernels and functions
at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);
// Host tensor dump and CSR manipulation functions
void dump_tensor(const char *fname, at::Tensor t);
void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols);
int part_csr_rows(int *row_perm,
const at::Tensor roff,
const at::Tensor cols,
int **part_off,
int **part_val);
int verify_part(const int npart,
const int *part_off,
const int *part_val,
const at::Tensor roff,
const at::Tensor cols);
void verify_part_new(const int nlon_out,
const int nlat_in,
const int nlon_in,
const int npart, // partitioning data
const int *part_off,
const int *part_val,
const at::Tensor roff,
const at::Tensor cols);
unsigned int next_pow2(unsigned int x);
// utility host functions and templates
template<unsigned int ALIGN>
int is_aligned(const void *ptr) {
static_assert(0 == (ALIGN & (ALIGN-1)));
return (0 == (uintptr_t(ptr) & (ALIGN-1)));
}
// utility device functions and templates
template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {
static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset");
return FLOATV_T{};
}
template<>
__device__ float __forceinline__ __vset<float>(float x) {
return x;
}
__device__ float __forceinline__ __vmul(float a, float b) {
return a*b;
}
__device__ float __forceinline__ __vadd(float a, float b) {
return a+b;
}
__device__ float __forceinline__ __vsub(float a, float b) {
return a-b;
}
__device__ float __forceinline__ __vred(float a) {
return a;
}
__device__ float __forceinline__ __vscale(float s, float v) {
return v*s;
}
__device__ float __forceinline__ __vdiv(float s, float v) {
return v/s;
}
template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
return make_float4(x, x, x, x);
}
__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}
__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}
__device__ float4 __forceinline__ __vsub(float4 a, float4 b) {
return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w);
}
__device__ float __forceinline__ __vred(float4 a) {
return a.x + a.y + a.z + a.w;
}
__device__ float4 __forceinline__ __vscale(float s, float4 v) {
return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}
__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);;
}
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE);
}
return val;
}
template<int BDIM_X,
int BDIM_Y=1,
int BDIM_Z=1,
typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {
const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;
val = __warp_sum(val);
if constexpr(NWARP > 1) {
int tid = threadIdx.x;
if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }
const int lid = tid%WARP_SIZE;
const int wid = tid/WARP_SIZE;
__shared__ VAL_T sh[NWARP];
if (lid == 0) {
sh[wid] = val;
}
__syncthreads();
if (wid == 0) {
val = (lid < NWARP) ? sh[lid] : 0;
val = __warp_sum(val);
__syncwarp();
if (!lid) {
sh[0] = val;
}
}
__syncthreads();
val = sh[0];
__syncthreads();
}
return val;
}
// transpose utils
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0);
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst){
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
block.y = WARPS_X_TILE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0231_k<WARP_SIZE, WARPS_X_TILE>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0);
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];;
}
}
}
}
return;
}
template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst){
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
block.y = WARPS_X_TILE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0312_k<WARP_SIZE, WARPS_X_TILE>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\
exit(EXIT_FAILURE); \
}}
......@@ -31,6 +31,7 @@
import torch
import torch.nn as nn
import torch.amp as amp
import torch.nn.functional as F
from typing import Optional
from abc import ABC, abstractmethod
......@@ -259,11 +260,14 @@ class W11LossS2(SphericalLossBase):
self.register_buffer("k_theta_mesh", k_theta_mesh)
def _compute_loss_term(self, prd: torch.Tensor, tar: torch.Tensor) -> torch.Tensor:
prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
prdtype = prd.dtype
with amp.autocast(device_type="cuda", enabled=False):
prd = prd.to(torch.float32)
prd_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(prd)).real
prd_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(prd)).real
tar_prime_fft2_phi_h = torch.fft.ifft2(1j * self.k_phi_mesh * torch.fft.fft2(tar)).real
tar_prime_fft2_theta_h = torch.fft.ifft2(1j * self.k_theta_mesh * torch.fft.fft2(tar)).real
# Return the element-wise loss term
return torch.abs(prd_prime_fft2_phi_h - tar_prime_fft2_phi_h) + torch.abs(prd_prime_fft2_theta_h - tar_prime_fft2_theta_h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment