Commit 8cb399ee authored by Mauro Bisson's avatar Mauro Bisson
Browse files

Optimized FWD kernel: custom permutations, gmem accesses reduction, vectorized access

* Replaced PyTorch's slow permutation ops with custom kernels, significantly improving performance (especially on GB200).
* Split kernel into general and specialized versions for num_channel <= 16384, significantly reducing memory accesses.
* Enabled float4-based vectorized memory access when pointer alignment and channel size allow, improving throughput.
* Added runtime dispatch logic for kernel specialization.
parent c485a1fb
......@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
......@@ -39,217 +39,816 @@
#include <cub/cub.cuh>
#include <limits>
using BlockReduceFloat256 = cub::BlockReduce<float, 256>;
using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define CHECK_ERROR(errorMessage) \
{ \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
static __device__ float __warp_sum(float val)
{
#pragma unroll
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
// one warp per (ho,wo)
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
extern __shared__ float sh[];
float *shy = sh + threadIdx.y * num_channels;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) { return; }
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)
#define MAX_LOCAL_ARR_LEN (16)
#define NEXT_POW2(x) (1u << (8*sizeof(x)-__builtin_clz(x-1)))
#define CHECK_CUDA(call) { \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
#define CHECK_ERROR(errorMessage) { \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
exit(EXIT_FAILURE); \
}}
// BEGIN - forward kernels and functions
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
return val;
}
template<int BDIM_X,
typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {
const int NWARP = BDIM_X/WARP_SIZE;
val = __warp_sum(val);
if constexpr(NWARP > 1) {
const int lid = threadIdx.x%WARP_SIZE;
const int wid = threadIdx.x/WARP_SIZE;
__shared__ VAL_T sh[NWARP];
if (lid == 0) {
sh[wid] = val;
}
__syncthreads();
if (wid == 0) {
val = (lid < NWARP) ? sh[lid] : 0;
val = __warp_sum(val);
__syncwarp();
if (!lid) {
sh[0] = val;
}
}
__syncthreads();
val = sh[0];
__syncthreads();
}
return val;
}
template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {}
template<>
__device__ float __forceinline__ __vset<float>(float x) {
return x;
}
__device__ float __forceinline__ __vmul(float a, float b) {
return a*b;
}
__device__ float __forceinline__ __vadd(float a, float b) {
return a+b;
}
__device__ float __forceinline__ __vred(float a) {
return a;
}
__device__ float __forceinline__ __vscale(float s, float v) {
return v*s;
}
__device__ float __forceinline__ __vdiv(float s, float v) {
return v/s;
}
template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
return make_float4(x, x, x, x);
}
__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}
__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}
__device__ float __forceinline__ __vred(float4 a) {
return a.x + a.y + a.z + a.w;
}
__device__ float4 __forceinline__ __vscale(float s, float4 v) {
return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}
__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);;
}
template<unsigned int ALIGN>
int is_aligned(const void *ptr) {
static_assert(0 == (ALIGN & (ALIGN-1)));
return 0 == (uintptr_t(ptr) & (ALIGN-1));
}
// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
template<int BDIM,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM)
void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_off,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
FLOATV_T *__restrict__ y) {
extern __shared__ __align__(sizeof(float4)) float sh[];
FLOATV_T *shy = reinterpret_cast<FLOATV_T *>(sh) + threadIdx.y*nchan;
const int batch = blockIdx.y;
const int wid = blockIdx.x*blockDim.y + threadIdx.y;
if (wid >= nlat_out*nlon_out) {
return;
}
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho*nlon_out);
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
shy[chan] = __vset<FLOATV_T>(0.f);
}
kx += batch*nlat_in*nlon_in*nchan;
vx += batch*nlat_in*nlon_in*nchan;
qy += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan;
y += batch*nlat_out*nlon_out*nchan + ho*nchan*nlon_out + wo*nchan;
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_off[ho];
const int64_t rend = psi_row_off[ho+1];
const int rlen = rend-rbeg;
for(int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg+off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
qdotkv = __vadd(qdotkv,
__vmul( qy[chan],
_kx[chan]));
}
float qdotk = __warp_sum(__vred(qdotkv));
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
shy[chan] = __vadd(__vscale(exp_save, shy[chan]),
__vscale( alpha, _vx[chan]));
}
qdotk_max = qdotk_max_tmp;
}
// alpha should be reciprocated here and then multiplied
// but for now I'm keeping the div branch the same output
// as my older versions
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
y[chan] = __vdiv(alpha_sum, shy[chan]);
}
#else
shy[chan] = 0;
alpha_sum = 1.0f / alpha_sum;
for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
y[chan] = __vscale(alpha_sum, shy[chan]);
}
#endif
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
return;
}
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho + 1];
template<typename FLOATV_T>
void launch_gen_attn_kernel(int batch_size,
int nloc,
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
at::Tensor psi_row_off,
at::Tensor psi_col_idx,
at::Tensor quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y;
auto _psi_row_off = psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _psi_col_idx = psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _quad_weights = quad_weights.packed_accessor32< float, 1, torch::RestrictPtrTraits>();
s2_attn_fwd_generic_vec_k<THREADS>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _psi_row_off, _psi_col_idx, _quad_weights, _yp);
return;
}
const int rlen = rend - rbeg;
// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1)
template<int BDIM_X,
int BDIM_Y,
int NLOC,
typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along channel dim
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
const FLOATV_T *__restrict__ kx,
const FLOATV_T *__restrict__ vx,
const FLOATV_T *__restrict__ qy,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_off,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
FLOATV_T *__restrict__ y) {
static_assert((BDIM_X == 32 && BDIM_Y > 1) ||
(BDIM_X > 32 && BDIM_Y == 1)) ;
constexpr int NLOC_M1 = NLOC-1;
const int tidx = threadIdx.x;
const int batch = blockIdx.y;
const int ctaid = blockIdx.x*blockDim.y + threadIdx.y;
if (ctaid >= nlat_out*nlon_out) {
return;
}
for (int off = 0; off < rlen; off++) {
FLOATV_T locy[NLOC];
const int64_t col = psi_col_idx[rbeg + off];
extern __shared__ __align__(sizeof(float4)) float sh[];
FLOATV_T *shq = reinterpret_cast<FLOATV_T *>(sh) + threadIdx.y*nchan + tidx;
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
const int ho = ctaid / nlon_out;
const int wo = ctaid - (ho*nlon_out);
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
kx += batch*nlat_in*nlon_in*nchan + tidx;
vx += batch*nlat_in*nlon_in*nchan + tidx;
qy += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx;
y += batch*nlat_out*nlon_out*nchan + ho*nlon_out*nchan + wo*nchan + tidx;
alpha_sum = alpha + alpha_sum * exp_save;
#pragma unroll
for(int i = 0; i < NLOC; i++) {
locy[i] = __vset<FLOATV_T>(0.f);
}
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
shq[i*BDIM_X] = qy[i*BDIM_X];
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
shq[NLOC_M1*BDIM_X] = qy[NLOC_M1*BDIM_X];
}
qdotk_max = qdotk_max_tmp;
}
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_off[ho];
const int64_t rend = psi_row_off[ho+1];
const int rlen = rend-rbeg;
for(int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg+off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const FLOATV_T *_kx = kx + hi*nlon_in*nchan + wip*nchan;
const FLOATV_T *_vx = vx + hi*nlon_in*nchan + wip*nchan;
FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
qdotkv = __vadd(qdotkv,
__vmul(shq[i*BDIM_X],
_kx[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
qdotkv = __vadd(qdotkv,
__vmul(shq[NLOC_M1*BDIM_X],
_kx[NLOC_M1*BDIM_X]));
}
float qdotk = __block_sum<BDIM_X>(__vred(qdotkv));
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
locy[i] = __vadd(__vscale(exp_save, locy[i]),
__vscale(alpha, _vx[i*BDIM_X]));
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
locy[NLOC_M1] = __vadd(__vscale(exp_save, locy[NLOC_M1]),
__vscale(alpha, _vx[NLOC_M1*BDIM_X]));
}
qdotk_max = qdotk_max_tmp;
}
#if 0
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
y[i*BDIM_X] = __vdiv(alpha_sum, locy[i]);
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
y[NLOC_M1*BDIM_X] = __vdiv(alpha_sum, locy[NLOC_M1]);
}
#else
alpha_sum = 1.0f / alpha_sum;
return;
#pragma unroll
for(int i = 0; i < NLOC_M1; i++) {
y[i*BDIM_X] = __vscale(alpha_sum, locy[i]);
}
if (NLOC_M1*BDIM_X+tidx < nchan) {
y[NLOC_M1*BDIM_X] = __vscale(alpha_sum, locy[NLOC_M1]);
}
#endif
return;
}
template<int BDIM_X,
int BDIM_Y,
int CUR_LOC_SIZE,
int MAX_LOC_SIZE, // max size of FLOATV_T[] local array
typename FLOATV_T>
void launch_spc_attn_kernel(int batch_size,
int nloc, // "BDIM_X*nloc" >= nchans
int nchans,
int nlat_in,
int nlon_in,
int nlat_out,
int nlon_out,
FLOATV_T *__restrict__ _kxp,
FLOATV_T *__restrict__ _vxp,
FLOATV_T *__restrict__ _qyp,
at::Tensor psi_row_off,
at::Tensor psi_col_idx,
at::Tensor quad_weights,
FLOATV_T *__restrict__ _yp,
cudaStream_t stream) {
if (CUR_LOC_SIZE == nloc) {
auto _psi_row_off = psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _psi_col_idx = psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
auto _quad_weights = quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>();
dim3 block(BDIM_X, BDIM_Y);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
//printf("block: (%d, %d)\n", block.x, block.y);
//printf("grid: (%d, %d)\n", grid.x, grid.y);
size_t shsize = sizeof(FLOATV_T)*nchans * block.y; // block.y > 1 iif block.x==32
s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
<<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, _psi_row_off, _psi_col_idx, _quad_weights, _yp);
return;
}
if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
launch_spc_attn_kernel<BDIM_X,
BDIM_Y,
CUR_LOC_SIZE+1,
MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
_kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp,
stream);
}
return;
}
void s2_attention_dipatch(int batch_size,
int nchans,
int nlon_in,
int nlat_out,
int nlon_out,
at::Tensor kxP,
at::Tensor vxP,
at::Tensor qyP,
at::Tensor psi_row_off,
at::Tensor psi_col_idx,
at::Tensor quad_weights,
at::Tensor yP,
cudaStream_t stream) {
static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
const int nlat_in = kxP.size(1);
// smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans
int bdimx;
bdimx = DIV_UP(nchans, MAX_LOCAL_ARR_LEN);
bdimx = max(bdimx, WARP_SIZE);
bdimx = NEXT_POW2(bdimx);
float *_kxp = reinterpret_cast<float *>(kxP.data_ptr());
float *_vxp = reinterpret_cast<float *>(vxP.data_ptr());
float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
float *_yp = reinterpret_cast<float *>(yP.data_ptr());
constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);
if (!is_aligned<sizeof(float4)>(_kxp) ||
!is_aligned<sizeof(float4)>(_vxp) ||
!is_aligned<sizeof(float4)>(_qyp) ||
!is_aligned<sizeof(float4)>(_yp) ||
(nchans % VEC_SIZE) != 0) {
const int nloc = DIV_UP(nchans, bdimx);
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
}
} else {
float4 *_kxp4 = reinterpret_cast<float4 *>(_kxp);
float4 *_vxp4 = reinterpret_cast<float4 *>(_vxp);
float4 *_qyp4 = reinterpret_cast<float4 *>(_qyp);
float4 *_yp4 = reinterpret_cast<float4 *>(_yp);
nchans /= VEC_SIZE;
const int nloc = DIV_UP(nchans, bdimx);
static constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;
// use 2D blocks only if 32 threads are enough
switch(bdimx) {
case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
}
}
return;
}
// END - forward kernels and functions
// BEGIN - tensor permutation kernels and functions
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : 0.f;
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
__global__ void empty_k() {}
static int getPtxver() {
cudaFuncAttributes attrs;
CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
return attrs.ptxVersion*10;
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
int nlon_out)
{
at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
}
return dst;
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : 0.f;
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];;
}
}
}
}
return;
}
at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
}
return dst;
}
// END - tensor permutation kernels and functions
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in,
int nlat_out,
int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
auto stream = at::cuda::getCurrentCUDAStream().stream();
// TODO: check sizes
size_t uo_num_channels = kx.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int batch_size = kx.size(0);
size_t uo_num_channels = kx.size(1);
torch::Tensor kxP = kx;
torch::Tensor vxP = vx;
torch::Tensor qyP = qy;
const int batch_size = kx.size(0);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
if (!k_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
if (!v_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
if (!q_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
cudaDeviceSynchronize();
nvtxRangePop();
torch::Tensor y = torch::empty_like(qy);
torch::Tensor yP = torch::empty_like(qyP);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
s2_attention_dipatch(batch_size,
uo_num_channels,
nlon_in,
nlat_out,
nlon_out,
kxP, vxP, qyP,
psi_row_off,
psi_col_idx,
quad_weights,
yP, // out tensor
stream);
size_t shared_size = sizeof(float) * uo_num_channels * block.y;
torch::Tensor y = yP;
if (!q_channel_first) { y = permute_4D_floatT_to0312(yP, stream); }
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
C10_CUDA_KERNEL_LAUNCH_CHECK();
s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
// match output layout to input layout
if (!q_channel_first) y = y.contiguous();
C10_CUDA_KERNEL_LAUNCH_CHECK();
return y;
return y;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment