Commit 73bfdc53 authored by Mauro Bisson

Optimize FWD kernel: reduced tail effect

* Added a new CSR array, psi_row_index, containing "ho" values sorted in descending order of CSR row length; this is used to process (ho, wo) points corresponding to longer rows before shorter ones, improving overlap and reducing the tail effect.
parent 8cb399ee
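The reordering idea is easiest to see in isolation. Below is a minimal host-side sketch, assuming a plain CSR offset array: row lengths are derived from the offsets and row indices are sorted longest-first, giving the permutation the kernels then read (as `row_idx`) to pick their (ho, wo) output point. The helper name `sort_rows_by_length` and the toy offsets are invented for illustration; the commit itself builds the permutation on the GPU with `cub::DeviceRadixSort::SortPairsDescending` (see `sortRows` in the diff below).

```cpp
// Host-side illustration only: sort CSR row indices by descending row length.
// The commit does the equivalent on the device with CUB radix sort.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

std::vector<int> sort_rows_by_length(const std::vector<int64_t>& row_off) {
    const int nrows = static_cast<int>(row_off.size()) - 1;
    std::vector<int> row_idx(nrows);
    std::iota(row_idx.begin(), row_idx.end(), 0);   // 0, 1, ..., nrows-1
    std::stable_sort(row_idx.begin(), row_idx.end(),
                     [&](int a, int b) {
                         // longer rows first
                         return (row_off[a + 1] - row_off[a]) >
                                (row_off[b + 1] - row_off[b]);
                     });
    return row_idx;
}

int main() {
    // Toy CSR offsets: row lengths are 2, 5, 1, 4.
    std::vector<int64_t> row_off = {0, 2, 7, 8, 12};
    for (int ho : sort_rows_by_length(row_off))
        std::printf("ho=%d len=%lld\n", ho,
                    (long long)(row_off[ho + 1] - row_off[ho]));
    // Blocks then process output rows in this order: 1, 3, 0, 2.
    return 0;
}
```

Scheduling the longest rows first means the last waves of blocks run the cheapest rows, which is what shortens the kernel tail.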
@@ -48,7 +48,6 @@
 #define TRANSP_WARPS_X_TILE_SM100 (4)
 #define MAX_LOCAL_ARR_LEN (16)
-#define NEXT_POW2(x) (1u << (8*sizeof(x)-__builtin_clz(x-1)))
 #define CHECK_CUDA(call) { \
 cudaError_t err = call; \
@@ -68,6 +67,7 @@
 // BEGIN - forward kernels and functions
 template<typename VAL_T>
 __device__ VAL_T __warp_sum(VAL_T val) {
@@ -79,16 +79,23 @@ __device__ VAL_T __warp_sum(VAL_T val) {
 }
 template<int BDIM_X,
+int BDIM_Y=1,
+int BDIM_Z=1,
 typename VAL_T>
 __device__ VAL_T __block_sum(VAL_T val) {
-const int NWARP = BDIM_X/WARP_SIZE;
+const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;
 val = __warp_sum(val);
 if constexpr(NWARP > 1) {
-const int lid = threadIdx.x%WARP_SIZE;
-const int wid = threadIdx.x/WARP_SIZE;
+int tid = threadIdx.x;
+if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
+if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }
+const int lid = tid%WARP_SIZE;
+const int wid = tid/WARP_SIZE;
 __shared__ VAL_T sh[NWARP];
@@ -143,7 +150,6 @@ __device__ float __forceinline__ __vdiv(float s, float v) {
 return v/s;
 }
 template<>
 __device__ float4 __forceinline__ __vset<float4>(float x) {
 return make_float4(x, x, x, x);
@@ -169,13 +175,6 @@ __device__ float4 __forceinline__ __vdiv(float s, float4 v) {
 return make_float4(s/v.x, s/v.y, s/v.z, s/v.w);;
 }
-template<unsigned int ALIGN>
-int is_aligned(const void *ptr) {
-static_assert(0 == (ALIGN & (ALIGN-1)));
-return 0 == (uintptr_t(ptr) & (ALIGN-1));
-}
 // called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y)
 template<int BDIM,
 typename FLOATV_T> // either float or float4
@@ -189,14 +188,14 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
 const FLOATV_T *__restrict__ kx,
 const FLOATV_T *__restrict__ vx,
 const FLOATV_T *__restrict__ qy,
-const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_off,
-const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
+const torch::PackedTensorAccessor32< int, 1, torch::RestrictPtrTraits> row_idx,
+const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
+const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
 const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
 FLOATV_T *__restrict__ y) {
-extern __shared__ __align__(sizeof(float4)) float sh[];
-FLOATV_T *shy = reinterpret_cast<FLOATV_T *>(sh) + threadIdx.y*nchan;
+extern __shared__ __align__(sizeof(float4)) float shext[];
+FLOATV_T *shy = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan;
 const int batch = blockIdx.y;
 const int wid = blockIdx.x*blockDim.y + threadIdx.y;
@@ -207,8 +206,9 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
 const int tidx = threadIdx.x;
-const int ho = wid / nlon_out;
-const int wo = wid - (ho*nlon_out);
+const int h = wid / nlon_out;
+const int wo = wid - (h*nlon_out);
+const int ho = row_idx[h];
 for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
 shy[chan] = __vset<FLOATV_T>(0.f);
@@ -222,14 +222,14 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
 float alpha_sum = 0.0f;
 float qdotk_max = -FLT_MAX;
-const int64_t rbeg = psi_row_off[ho];
-const int64_t rend = psi_row_off[ho+1];
+const int64_t rbeg = row_off[ho];
+const int64_t rend = row_off[ho+1];
 const int rlen = rend-rbeg;
 for(int off = 0; off < rlen; off++) {
-const int64_t col = psi_col_idx[rbeg+off];
+const int64_t col = col_idx[rbeg+off];
 const int hi = col / nlon_in;
 const int wi = col - (hi*nlon_in);
@@ -265,19 +265,11 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
 qdotk_max = qdotk_max_tmp;
 }
-// alpha should be reciprocated here and then multiplied
-// but for now I'm keeping the div branch the same output
-// as my older versions
-#if 0
-for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
-y[chan] = __vdiv(alpha_sum, shy[chan]);
-}
-#else
 alpha_sum = 1.0f / alpha_sum;
 for(int chan = tidx; chan < nchan; chan += WARP_SIZE) {
 y[chan] = __vscale(alpha_sum, shy[chan]);
 }
-#endif
 return;
 }
@@ -292,8 +284,9 @@ void launch_gen_attn_kernel(int batch_size,
 FLOATV_T *__restrict__ _kxp,
 FLOATV_T *__restrict__ _vxp,
 FLOATV_T *__restrict__ _qyp,
-at::Tensor psi_row_off,
-at::Tensor psi_col_idx,
+at::Tensor row_idx,
+at::Tensor row_off,
+at::Tensor col_idx,
 at::Tensor quad_weights,
 FLOATV_T *__restrict__ _yp,
 cudaStream_t stream) {
@@ -303,14 +296,14 @@ void launch_gen_attn_kernel(int batch_size,
 size_t shsize = sizeof(FLOATV_T)*nchans * block.y;
-auto _psi_row_off = psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
-auto _psi_col_idx = psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
+auto _row_idx = row_idx.packed_accessor32< int, 1, torch::RestrictPtrTraits>();
+auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
+auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
 auto _quad_weights = quad_weights.packed_accessor32< float, 1, torch::RestrictPtrTraits>();
 s2_attn_fwd_generic_vec_k<THREADS>
 <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
-_kxp, _vxp, _qyp, _psi_row_off, _psi_col_idx, _quad_weights, _yp);
+_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
 return;
 }
@@ -329,11 +322,14 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
 const FLOATV_T *__restrict__ kx,
 const FLOATV_T *__restrict__ vx,
 const FLOATV_T *__restrict__ qy,
-const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_off,
-const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
+const torch::PackedTensorAccessor32< int, 1, torch::RestrictPtrTraits> row_idx,
+const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> row_off,
+const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> col_idx,
 const torch::PackedTensorAccessor32< float, 1, torch::RestrictPtrTraits> quad_weights,
 FLOATV_T *__restrict__ y) {
+static_assert(0 == (BDIM_X & (BDIM_X-1)));
+static_assert(0 == (BDIM_Y & (BDIM_Y-1)));
 static_assert((BDIM_X == 32 && BDIM_Y > 1) ||
 (BDIM_X > 32 && BDIM_Y == 1)) ;
@@ -349,11 +345,12 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
 FLOATV_T locy[NLOC];
-extern __shared__ __align__(sizeof(float4)) float sh[];
-FLOATV_T *shq = reinterpret_cast<FLOATV_T *>(sh) + threadIdx.y*nchan + tidx;
-const int ho = ctaid / nlon_out;
-const int wo = ctaid - (ho*nlon_out);
+extern __shared__ __align__(sizeof(float4)) float shext[];
+FLOATV_T *shq = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan + tidx;
+const int h = ctaid / nlon_out;
+const int wo = ctaid - (h*nlon_out);
+const int ho = row_idx[h];
 kx += batch*nlat_in*nlon_in*nchan + tidx;
 vx += batch*nlat_in*nlon_in*nchan + tidx;
@@ -376,14 +373,14 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
 float alpha_sum = 0.0f;
 float qdotk_max = -FLT_MAX;
-const int64_t rbeg = psi_row_off[ho];
-const int64_t rend = psi_row_off[ho+1];
+const int64_t rbeg = row_off[ho];
+const int64_t rend = row_off[ho+1];
 const int rlen = rend-rbeg;
 for(int off = 0; off < rlen; off++) {
-const int64_t col = psi_col_idx[rbeg+off];
+const int64_t col = col_idx[rbeg+off];
 const int hi = col / nlon_in;
 const int wi = col - (hi*nlon_in);
@@ -406,8 +403,9 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
 _kx[NLOC_M1*BDIM_X]));
 }
-float qdotk = __block_sum<BDIM_X>(__vred(qdotkv));
+float qdotk = __vred(qdotkv);
+if constexpr(BDIM_X == 32) { qdotk = __warp_sum(qdotk); }
+else { qdotk = __block_sum<BDIM_X>(qdotk); }
 float qdotk_max_tmp;
 float alpha;
@@ -431,15 +429,7 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
 qdotk_max = qdotk_max_tmp;
 }
-#if 0
-#pragma unroll
-for(int i = 0; i < NLOC_M1; i++) {
-y[i*BDIM_X] = __vdiv(alpha_sum, locy[i]);
-}
-if (NLOC_M1*BDIM_X+tidx < nchan) {
-y[NLOC_M1*BDIM_X] = __vdiv(alpha_sum, locy[NLOC_M1]);
-}
-#else
 alpha_sum = 1.0f / alpha_sum;
 #pragma unroll
@@ -449,7 +439,7 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
 if (NLOC_M1*BDIM_X+tidx < nchan) {
 y[NLOC_M1*BDIM_X] = __vscale(alpha_sum, locy[NLOC_M1]);
 }
-#endif
 return;
 }
@@ -468,16 +458,18 @@ void launch_spc_attn_kernel(int batch_size,
 FLOATV_T *__restrict__ _kxp,
 FLOATV_T *__restrict__ _vxp,
 FLOATV_T *__restrict__ _qyp,
-at::Tensor psi_row_off,
-at::Tensor psi_col_idx,
+at::Tensor row_idx,
+at::Tensor row_off,
+at::Tensor col_idx,
 at::Tensor quad_weights,
 FLOATV_T *__restrict__ _yp,
 cudaStream_t stream) {
 if (CUR_LOC_SIZE == nloc) {
-auto _psi_row_off = psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
-auto _psi_col_idx = psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
+auto _row_idx = row_idx.packed_accessor32< int, 1, torch::RestrictPtrTraits>();
+auto _row_off = row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
+auto _col_idx = col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>();
 auto _quad_weights = quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>();
 dim3 block(BDIM_X, BDIM_Y);
@@ -490,7 +482,7 @@ void launch_spc_attn_kernel(int batch_size,
 s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
 <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
-_kxp, _vxp, _qyp, _psi_row_off, _psi_col_idx, _quad_weights, _yp);
+_kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
 return;
 }
 if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
@@ -498,13 +490,92 @@ void launch_spc_attn_kernel,
 BDIM_Y,
 CUR_LOC_SIZE+1,
 MAX_LOC_SIZE>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out,
-_kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp,
+_kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp,
 stream);
 }
 return;
 }
+__global__ void set_rlen_rids_k(const int n,
+const int64_t *__restrict__ offs,
+int *__restrict__ rids,
+int *__restrict__ rlen) {
+const int nth = gridDim.x*blockDim.x;
+const int tid = blockIdx.x*blockDim.x + threadIdx.x;
+for(int i = tid; i < n; i += nth) {
+rids[i] = i;
+rlen[i] = offs[i+1]-offs[i];
+}
+return;
+}
+
+at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
+int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());
+auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());
+torch::Tensor rids_d = torch::empty({nlat_out}, options);
+torch::Tensor rlen_d = torch::empty({nlat_out}, options);
+int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
+int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());
+const int grid = DIV_UP(nlat_out, THREADS);
+const int block = THREADS;
+set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
+_row_off_d,
+_rids_d,
+_rlen_d);
+torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
+torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);
+int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
+int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());
+size_t temp_storage_bytes = 0;
+CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
+_rlen_d, _rlen_sort_d,
+_rids_d, _rids_sort_d,
+nlat_out, 0, sizeof(*_rlen_d)*8, stream));
+options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
+torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);
+void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());
+CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
+_rlen_d, _rlen_sort_d,
+_rids_d, _rids_sort_d,
+nlat_out, 0, sizeof(*_rlen_d)*8, stream));
+return rids_sort_d;
+}
+
+template<unsigned int ALIGN>
+int is_aligned(const void *ptr) {
+static_assert(0 == (ALIGN & (ALIGN-1)));
+return (0 == (uintptr_t(ptr) & (ALIGN-1)));
+}
+
+static unsigned int next_pow2(unsigned int x) {
+x -= 1;
+#pragma unroll
+for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
+x |= x >> i;
+}
+return x+1;
+}
+
-void s2_attention_dipatch(int batch_size,
+static void s2_attention_dipatch(int batch_size,
 int nchans,
 int nlon_in,
 int nlat_out,
@@ -512,21 +583,25 @@ void s2_attention_dipatch(int batch_size,
 at::Tensor kxP,
 at::Tensor vxP,
 at::Tensor qyP,
-at::Tensor psi_row_off,
-at::Tensor psi_col_idx,
+at::Tensor row_off,
+at::Tensor col_idx,
 at::Tensor quad_weights,
 at::Tensor yP,
 cudaStream_t stream) {
 static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
+// sort row indices (ho-s) in descending order
+// based on (row_off[ho+1]-row_off[ho])
+at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
 const int nlat_in = kxP.size(1);
 // smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans
 int bdimx;
 bdimx = DIV_UP(nchans, MAX_LOCAL_ARR_LEN);
 bdimx = max(bdimx, WARP_SIZE);
-bdimx = NEXT_POW2(bdimx);
+bdimx = next_pow2(bdimx);
 float *_kxp = reinterpret_cast<float *>(kxP.data_ptr());
 float *_vxp = reinterpret_cast<float *>(vxP.data_ptr());
@@ -545,15 +620,17 @@ void s2_attention_dipatch(int batch_size,
 // use 2D blocks only if 32 threads are enough
 switch(bdimx) {
-case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
-case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
-case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
-case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
-case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
-case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
-default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, psi_row_off, psi_col_idx, quad_weights, _yp, stream); break;
+case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
+case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
+case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
+case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
+case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
+case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_ARR_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
+default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, row_idx, row_off, col_idx, quad_weights, _yp, stream); break;
 }
 } else {
 float4 *_kxp4 = reinterpret_cast<float4 *>(_kxp);
 float4 *_vxp4 = reinterpret_cast<float4 *>(_vxp);
 float4 *_qyp4 = reinterpret_cast<float4 *>(_qyp);
@@ -566,13 +643,13 @@ void s2_attention_dipatch(int batch_size,
 // use 2D blocks only if 32 threads are enough
 switch(bdimx) {
-case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
-case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
-case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
-case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
-case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
-case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
-default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, psi_row_off, psi_col_idx, quad_weights, _yp4, stream); break;
+case 32: launch_spc_attn_kernel< 32, 2, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
+case 64: launch_spc_attn_kernel< 64, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
+case 128: launch_spc_attn_kernel< 128, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
+case 256: launch_spc_attn_kernel< 256, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
+case 512: launch_spc_attn_kernel< 512, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
+case 1024: launch_spc_attn_kernel<1024, 1, 1, MAX_LOCAL_VEC_LEN>(batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
+default: launch_gen_attn_kernel (batch_size, nloc, nchans, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, row_idx, row_off, col_idx, quad_weights, _yp4, stream); break;
 }
 }
@@ -638,9 +715,7 @@ void permute_to0231_k(const int nchn,
 }
 }
 }
 return;
 }
 __global__ void empty_k() {}
@@ -651,7 +726,7 @@ static int getPtxver() {
 return attrs.ptxVersion*10;
 }
-at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
+static at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
 dim3 block;
 dim3 grid;
@@ -748,12 +823,10 @@ void permute_to0312_k(const int nchn,
 }
 }
 }
 return;
 }
-at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
+static at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
 dim3 block;
 dim3 grid;
@@ -803,7 +876,6 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
 int nlon_in,
 int nlat_out,
 int nlon_out) {
 CHECK_CUDA_TENSOR(kx);
 CHECK_CUDA_TENSOR(vx);
 CHECK_CUDA_TENSOR(qy);
...