Commit 45fc2a46 authored by Thorsten Kurth

cleanup with contiguous checks

parent 51200bda
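The commit converts the CUDA attention wrappers to explicit dtype and memory-format handling: inputs are cast to float32, channels-last layout is detected from the channel stride, and outputs are permuted and cast back afterwards. A minimal PyTorch sketch of that pattern, using illustrative helper names (is_channels_last, to_kernel_layout) that are not part of the extension:

import torch

def is_channels_last(t: torch.Tensor) -> bool:
    # same idea as the stride check in the CUDA wrappers below:
    # a channel stride of 1 means the data is laid out as [batch, h, w, channel]
    return t.stride(1) == 1

def to_kernel_layout(t: torch.Tensor) -> torch.Tensor:
    # the kernels expect contiguous float32 channels-last input
    return t.to(torch.float32).contiguous(memory_format=torch.channels_last)

x = torch.randn(4, 8, 6, 12)                        # NCHW, float32
y = x.to(memory_format=torch.channels_last).half()  # NHWC, float16
print(is_channels_last(x), is_channels_last(y))     # False True
out = to_kernel_layout(y)                           # run the kernel on this ...
out = out.to(y.dtype)                               # ... and restore the original dtype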
@@ -78,7 +78,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 1, 1, (2, 4), (2, 4), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 4, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
],
skip_on_empty=True,
@@ -156,8 +157,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
[
# Format: [batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol]
[4, 4, 1, (6, 12), (6, 12), "equiangular", "equiangular", 1e-2, 0],
# [4, 4, 2, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
# [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
[4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
[4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
],
@@ -520,6 +520,16 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
B, _, H, W = grad_output.shape
grad_output = grad_output.reshape(B*nh, -1, H, W)
# save type and convert to float32
kw_dtype = kw.dtype
vw_dtype = vw.dtype
qw_dtype = qw.dtype
kw = kw.to(torch.float32).contiguous()
vw = vw.to(torch.float32).contiguous()
qw = qw.to(torch.float32).contiguous()
grad_output = grad_output.to(torch.float32).contiguous()
dkw,dvw,dqw = attention_cuda_extension.backward_dkvq(kw, vw, qw, grad_output,
quad_weights,
col_idx, row_off,
@@ -533,6 +543,11 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
_, C, H, W = dqw.shape
dqw = dqw.reshape(B, -1, H, W)
# convert precision back to the original dtypes
dkw = dkw.to(dtype=kw_dtype)
dvw = dvw.to(dtype=vw_dtype)
dqw = dqw.to(dtype=qw_dtype)
# input grads
dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
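The hunk above saves each input tensor's dtype, runs the extension kernel on contiguous float32 copies, and casts the resulting gradients back. A condensed sketch of that save/convert/restore pattern; backward_in_fp32 and fake_bwd are illustrative stand-ins, not part of the extension:

import torch

def backward_in_fp32(fn, *tensors):
    # run fn on contiguous float32 copies, then cast each result
    # back to the corresponding input's original dtype
    dtypes = [t.dtype for t in tensors]
    args = [t.to(torch.float32).contiguous() for t in tensors]
    grads = fn(*args)
    return tuple(g.to(dtype=d) for g, d in zip(grads, dtypes))

# stand-in for attention_cuda_extension.backward_dkvq
fake_bwd = lambda kw, vw, qw, dy: (kw * dy, vw * dy, qw * dy)
kw = torch.randn(2, 3, 4, 5, dtype=torch.float16)
vw, qw, dy = kw.clone(), kw.clone(), kw.clone()
dkw, dvw, dqw = backward_in_fp32(fake_bwd, kw, vw, qw, dy)
assert dkw.dtype == torch.float16 and dkw.is_contiguous()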
@@ -34,7 +34,11 @@
#include <cstdint>
#include <torch/torch.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA)
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast))
#define CHECK_CUDA_INPUT_TENSOR(x) \
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
@@ -785,11 +785,13 @@ static void s2_attn_bwd_dispatch(int batch_size,
at::Tensor quad_weights,
at::Tensor dkxP,
at::Tensor dvxP,
at::Tensor dqyP,
cudaStream_t stream) {
at::Tensor dqyP) {
static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// sort row indices (ho-s) in descending order
// based on (row_off[ho+1]-row_off[ho])
at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
@@ -890,122 +892,129 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
int nlon_in, int nlat_out, int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_INPUT_TENSOR(kx);
CHECK_CUDA_INPUT_TENSOR(vx);
CHECK_CUDA_INPUT_TENSOR(qy);
CHECK_CUDA_INPUT_TENSOR(dy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
#if 0
// #if 0
// // extract dtype
// auto kx_type = kx.dtype();
// auto vx_type = vx.dtype();
// auto qy_type = qy.dtype();
// auto dy_type = dy.dtype();
// // extract memory format
// auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
// auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
// auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
// auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
// // convert to channels-last
// auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// // create output arrays
// auto dydk = torch::zeros_like(qyP);
// auto dydv = torch::zeros_like(qyP);
// auto dydq = torch::zeros_like(qyP);
// size_t uo_num_channels = kx.size(1);
// const int batch_size = kx.size(0);
// dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
// dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
// size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
// cudaEvent_t start, stop;
// float milliseconds = 0;
// CHECK_CUDA(cudaEventCreate(&start));
// CHECK_CUDA(cudaEventCreate(&stop));
// CHECK_CUDA(cudaEventRecord(start, stream));
// s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
// uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
// psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
// quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
// CHECK_CUDA(cudaEventRecord(stop, stream));
// CHECK_CUDA(cudaEventSynchronize(stop));
// CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// CHECK_CUDA(cudaEventDestroy(start));
// CHECK_CUDA(cudaEventDestroy(stop));
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// // Permute outputs back to memory layout given by input. if input had channels
// // first, leave it in that layout, otherwise permute layout back to [batch,
// // channel, ho, wo]
// // convert back to original dtype
// dydk = dydk.to(kx_type);
// dydv = dydv.to(vx_type);
// dydq = dydq.to(qy_type);
// // permute back to original layout
// if (!kx_is_channels_last) {
// dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
// } else {
// dydk = dydk.to(kx_type);
// }
// if (!vx_is_channels_last) {
// dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
// } else {
// dydv = dydv.to(vx_type);
// }
// if (!qy_is_channels_last) {
// dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
// } else {
// dydq = dydq.to(qy_type);
// }
// return std::make_tuple(dydk, dydv, dydq);
// #else
const size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
// extract dtype
auto kx_type = kx.dtype();
auto vx_type = vx.dtype();
auto qy_type = qy.dtype();
auto dy_type = dy.dtype();
// extract memory format
auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
torch::Tensor kxP = kx.to(torch::kFloat32);
torch::Tensor vxP = vx.to(torch::kFloat32);
torch::Tensor qyP = qy.to(torch::kFloat32);
torch::Tensor dyP = dy.to(torch::kFloat32);
// convert to channels-last
auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// extract memory format: checking the channel stride is more robust than
// is_contiguous(at::MemoryFormat::ChannelsLast), which fails for num_channels == 1
bool kx_is_channels_last = kxP.strides()[1] == 1;
bool vx_is_channels_last = vxP.strides()[1] == 1;
bool qy_is_channels_last = qyP.strides()[1] == 1;
bool dy_is_channels_last = dyP.strides()[1] == 1;
// create output arrays
auto dydk = torch::zeros_like(qyP);
auto dydv = torch::zeros_like(qyP);
auto dydq = torch::zeros_like(qyP);
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
// convert back to original dtype
dydk = dydk.to(kx_type);
dydv = dydv.to(vx_type);
dydq = dydq.to(qy_type);
// permute back to original layout
if (!kx_is_channels_last) {
dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
} else {
dydk = dydk.to(kx_type);
}
if (!vx_is_channels_last) {
dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
} else {
dydv = dydv.to(vx_type);
}
if (!qy_is_channels_last) {
dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
} else {
dydq = dydq.to(qy_type);
}
return std::make_tuple(dydk, dydv, dydq);
#else
const size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
torch::Tensor kxP = kx;
torch::Tensor vxP = vx;
torch::Tensor qyP = qy;
torch::Tensor dyP = dy;
auto kx_channel_first = kx.strides()[1] == 1;
auto vx_channel_first = vx.strides()[1] == 1;
auto qy_channel_first = qy.strides()[1] == 1;
auto dy_channel_first = dy.strides()[1] == 1;
if (!kx_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
if (!vx_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
if (!qy_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
if (!dy_channel_first) { dyP = permute_4D_floatT_to0231(dy, stream); }
// transpose if required
if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }
if (!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); }
torch::Tensor dkxP = torch::zeros_like(kxP);
torch::Tensor dvxP = torch::zeros_like(vxP);
@@ -1020,17 +1029,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
psi_row_off,
psi_col_idx,
quad_weights,
dkxP, dvxP, dqyP, // out tensors
stream);
dkxP, dvxP, dqyP);
torch::Tensor dkx = dkxP;
torch::Tensor dvx = dvxP;
torch::Tensor dqy = dqyP;
if (!kx_channel_first) { dkx = permute_4D_floatT_to0312(dkxP, stream); }
if (!vx_channel_first) { dvx = permute_4D_floatT_to0312(dvxP, stream); }
if (!qy_channel_first) { dqy = permute_4D_floatT_to0312(dqyP, stream); }
if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); }
if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); }
if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); }
// convert precision back to the original dtypes
dkx = dkx.to(kx_type);
dvx = dvx.to(vx_type);
dqy = dqy.to(qy_type);
return std::make_tuple(dkx, dvx, dqy);
#endif
// #endif
}
@@ -374,11 +374,13 @@ static void s2_attn_fwd_dispatch(int batch_size,
at::Tensor row_off,
at::Tensor col_idx,
at::Tensor quad_weights,
at::Tensor yP,
cudaStream_t stream) {
at::Tensor yP) {
static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// sort row indices (ho-s) in descending order
// based on (row_off[ho+1]-row_off[ho])
at::Tensor row_idx = sortRows(nlat_out, row_off, stream);
@@ -470,32 +472,33 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
int nlon_in,
int nlat_out,
int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_INPUT_TENSOR(kx);
CHECK_CUDA_INPUT_TENSOR(vx);
CHECK_CUDA_INPUT_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
torch::Tensor kxP = kx;
torch::Tensor vxP = vx;
torch::Tensor qyP = qy;
// extract dtype
auto qy_type = qy.dtype();
torch::Tensor kxP = kx.to(torch::kFloat32);
torch::Tensor vxP = vx.to(torch::kFloat32);
torch::Tensor qyP = qy.to(torch::kFloat32);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
// checking the channel stride is more robust than is_contiguous(at::MemoryFormat::ChannelsLast),
// which fails for num_channels == 1
bool kx_is_channels_last = kxP.strides()[1] == 1;
bool vx_is_channels_last = vxP.strides()[1] == 1;
bool qy_is_channels_last = qyP.strides()[1] == 1;
if (!k_channel_first) { kxP = permute_4D_floatT_to0231(kx, stream); }
if (!v_channel_first) { vxP = permute_4D_floatT_to0231(vx, stream); }
if (!q_channel_first) { qyP = permute_4D_floatT_to0231(qy, stream); }
if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }
torch::Tensor yP = torch::empty_like(qyP);
@@ -508,11 +511,13 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
psi_row_off,
psi_col_idx,
quad_weights,
yP, // out tensor
stream);
yP);
torch::Tensor y = yP;
if (!q_channel_first) { y = permute_4D_floatT_to0312(yP, stream); }
if (!qy_is_channels_last) { y = permute_4D_to0312(y); }
// convert precision back to the original dtype
y = y.to(qy_type);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -111,66 +111,6 @@ at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {
// BEGIN - 4D tensor permutation kernels and functions
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : 0.f;
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
__global__ void empty_k() {}
static int getPtxver() {
@@ -179,144 +119,96 @@ static int getPtxver() {
return attrs.ptxVersion*10;
}
at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
at::Tensor permute_4D_to0231(at::Tensor src) {
dim3 block;
dim3 grid;
//dim3 block;
//dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
//block.x = WARP_SIZE;
//grid.x = DIV_UP(src.size(1), block.x);
//grid.y = DIV_UP(src.size(3), block.x);
//grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
//assert(grid.y < 65536);
//assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
//block.y = TRANSP_WARPS_X_TILE_GENERIC;
//permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
// <<<grid, block, 0, stream>>>(src.size(1),
// src.size(2),
// src.size(3),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0231_k_tile_generic");
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
}));
//block.y = TRANSP_WARPS_X_TILE_SM100;
//permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
// <<<grid, block, 0, stream>>>(src.size(1),
// src.size(2),
// src.size(3),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0231_k_tile_sm100");
}
return dst;
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> src,
torch::PackedTensorAccessor32<VAL_T, 4, torch::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : 0.f;
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
at::Tensor permute_4D_to0312(at::Tensor src) {
dim3 block;
dim3 grid;
//dim3 block;
//dim3 grid;
block.x = WARP_SIZE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
//block.x = WARP_SIZE;
//grid.x = DIV_UP(src.size(2), block.x);
//grid.y = DIV_UP(src.size(3), block.x);
//grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
//assert(grid.y < 65536);
//assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src.device());
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
const int ptxv = getPtxver();
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
block.y = TRANSP_WARPS_X_TILE_GENERIC;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
//block.y = TRANSP_WARPS_X_TILE_GENERIC;
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
//permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
// <<<grid, block, 0, stream>>>(src.size(3),
// src.size(1),
// src.size(2),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0312_k_tile_generic");
} else {
block.y = TRANSP_WARPS_X_TILE_SM100;
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
}));
//block.y = TRANSP_WARPS_X_TILE_SM100;
//permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
// <<<grid, block, 0, stream>>>(src.size(3),
// src.size(1),
// src.size(2),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0312_k_tile_sm100");
}
@@ -40,8 +40,8 @@
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);
// 4D tensor permutation kernels and functions
at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream);
at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream);
at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);
// Host tensor dump and CSR manipulation functions
void dump_tensor(const char *fname, at::Tensor t);
@@ -200,3 +200,174 @@ __device__ VAL_T __block_sum(VAL_T val) {
}
return val;
}
// transpose utils
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
const int nlat,
const int nlon,
const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int coff = blockIdx.x*BDIM_X; // channel offset
const int woff = blockIdx.y*BDIM_X; // width offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
} else {
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0);
}
}
__syncthreads();
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (woff + j+tidy < nlon) {
dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
template<int TRANSP_WARPS_X_TILE_SIZE, typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst){
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
block.y = TRANSP_WARPS_X_TILE_SIZE;
grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SIZE>
<<<grid, block, 0, stream>>>(src.size(1),
src.size(2),
src.size(3),
src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
template<int BDIM_X,
int BDIM_Y,
typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
const int nlat,
const int nlon,
const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {
static_assert(!(BDIM_X & (BDIM_X-1)));
static_assert(!(BDIM_Y & (BDIM_Y-1)));
static_assert(BDIM_X >= BDIM_Y);
__shared__ VAL_T sh[BDIM_X][BDIM_X+1];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int woff = blockIdx.x*BDIM_X; // width offset
const int coff = blockIdx.y*BDIM_X; // channel offset
const int batch = blockIdx.z / nlat; // batch (same for all block)
const int h = blockIdx.z - (batch * nlat); // height (same for all block)
const int nchn_full = (nchn-coff) >= BDIM_X;
const int nlon_full = (nlon-woff) >= BDIM_X;
if (nchn_full && nlon_full) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
}
__syncthreads();
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
} else {
if (coff+tidx < nchn) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0);
}
}
__syncthreads();
if (woff+tidx < nlon) {
#pragma unroll
for(int j = 0; j < BDIM_X; j += BDIM_Y) {
if (coff + j+tidy < nchn) {
dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
}
}
}
}
return;
}
template<int TRANSP_WARPS_X_TILE_SIZE, typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst){
dim3 block;
dim3 grid;
block.x = WARP_SIZE;
block.y = TRANSP_WARPS_X_TILE_SIZE;
grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0);
assert(grid.y < 65536);
assert(grid.z < 65536);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SIZE>
<<<grid, block, 0, stream>>>(src.size(3),
src.size(1),
src.size(2),
src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
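The two launchers above implement the 0231 (channels-first to channels-last) and 0312 (channels-last to channels-first) permutations for [batch, channel, lat, lon] tensors. In PyTorch terms the results should match permute(0, 2, 3, 1) and permute(0, 3, 1, 2) followed by contiguous(); a plain-torch reference for the index algebra (no extension call involved):

import torch

src = torch.randn(2, 5, 6, 12)                       # [batch, chan, lat, lon]
to0231 = src.permute(0, 2, 3, 1).contiguous()        # layout produced by permute_4D_to0231
back = to0231.permute(0, 3, 1, 2).contiguous()       # layout restored by permute_4D_to0312
assert to0231.shape == (2, 6, 12, 5)
assert torch.equal(src, back)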