Commit c46b6925 authored by Max Rietmann's avatar Max Rietmann
Browse files

Applied new formatting

parent 1ea5c4ca
--- ---
BasedOnStyle: Webkit BasedOnStyle: Webkit
IndentWidth: 2 IndentWidth: 4
AccessModifierOffset: -2 AccessModifierOffset: -2
AlignAfterOpenBracket: Align AlignAfterOpenBracket: Align
AlignTrailingComments: true AlignTrailingComments: true
...@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false ...@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
BreakConstructorInitializers: AfterColon BreakConstructorInitializers: AfterColon
ColumnLimit: 120 ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 2 ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 2 ContinuationIndentWidth: 4
Cpp11BracedListStyle: true Cpp11BracedListStyle: true
FixNamespaceComments: true FixNamespaceComments: true
NamespaceIndentation: All NamespaceIndentation: All
......
...@@ -45,125 +45,125 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>; ...@@ -45,125 +45,125 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32) #define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF) #define FULL_MASK (0xFFFFFFFF)
#define THREADS (64) #define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b)) #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define NNZ_TRESH (32) #define NNZ_TRESH (32)
#define CHECK_CUDA(call) \ #define CHECK_CUDA(call) \
{ \ { \
cudaError_t err = call; \ cudaError_t err = call; \
if (cudaSuccess != err) { \ if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \ exit(EXIT_FAILURE); \
} \ } \
} }
#define CHECK_ERROR(errorMessage) \ #define CHECK_ERROR(errorMessage) \
{ \ { \
cudaError_t err = cudaGetLastError(); \ cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \ if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \ fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
cudaGetErrorString(err)); \ cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \ exit(EXIT_FAILURE); \
} \ } \
} }
static __device__ float __warp_sum(float val) static __device__ float __warp_sum(float val)
{ {
#pragma unroll #pragma unroll
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); } for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val; return val;
} }
// easier to understand version of manual shfl_xor_sync, performance appears similar // easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) static __device__ float __warp_sum_cub(float val)
{ {
// use cub to reduce within a warp // use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage; __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0) // 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val); float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads // 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0); sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum; return sum;
} }
// one warp per (ho,wo) // one warp per (ho,wo)
template <int BDIM_X> template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel( __global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out, int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy, const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y, torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx, const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset, const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{ {
extern __shared__ float sh[]; extern __shared__ float sh[];
float *shy = sh + threadIdx.y * num_channels; float *shy = sh + threadIdx.y * num_channels;
const uint64_t batchId = blockIdx.y; const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) { return; } if (wid >= uint64_t(nlat_out) * nlon_in) { return; }
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
const int ho = wid / nlon_out; const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out); const int wo = wid - (ho * nlon_out);
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0 #if 0
// useless read, y is always zeroed before kernel is called // useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo]; shy[chan] = y[batchId][chan][ho][wo];
#else #else
shy[chan] = 0; shy[chan] = 0;
#endif #endif
} }
float alpha_sum = 0.0f; float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX; float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho]; const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho + 1]; const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg; const int rlen = rend - rbeg;
for (int off = 0; off < rlen; off++) { for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off]; const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in; const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in); const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f; float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip]; qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
} }
qdotk = __warp_sum_cub(qdotk); qdotk = __warp_sum_cub(qdotk);
float qdotk_max_tmp; float qdotk_max_tmp;
float alpha; float alpha;
float exp_save; float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk); qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp); exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum * exp_save; alpha_sum = alpha + alpha_sum * exp_save;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha; shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
}
qdotk_max = qdotk_max_tmp;
} }
qdotk_max = qdotk_max_tmp;
}
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; } for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }
return; return;
} }
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights, torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
...@@ -171,85 +171,85 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, ...@@ -171,85 +171,85 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
int nlon_out) int nlon_out)
{ {
CHECK_CUDA_TENSOR(kx); CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx); CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy); CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights); CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx); CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off); CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes // TODO: check sizes
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1); size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0); const int batch_size = kx.size(0);
auto k_channel_first = kx.strides()[1] == 1; auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1; auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1; auto q_channel_first = qy.strides()[1] == 1;
// transpose inputs so that channels are in the last dimension, allowing for // transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access // coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs"); nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo] // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor(); auto kxP = at::Tensor();
if (!k_channel_first) { if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n"); // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2}); kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else { } else {
kxP = kx; kxP = kx;
} }
auto vxP = at::Tensor(); auto vxP = at::Tensor();
if (!v_channel_first) { if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n"); // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2}); vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else { } else {
vxP = vx; vxP = vx;
} }
auto qyP = at::Tensor(); auto qyP = at::Tensor();
if (!q_channel_first) { if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n"); // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2}); qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else { } else {
qyP = qy; qyP = qy;
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
nvtxRangePop(); nvtxRangePop();
torch::Tensor y = torch::empty_like(qy); torch::Tensor y = torch::empty_like(qy);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE); dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size); dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * block.y; size_t shared_size = sizeof(float) * uo_num_channels * block.y;
cudaEvent_t start, stop; cudaEvent_t start, stop;
float milliseconds = 0; float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start)); CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop)); CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream)); CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>( s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(), y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(), psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(), psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()); quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream)); CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop)); CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop)); CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds); // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start)); CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop)); CHECK_CUDA(cudaEventDestroy(stop));
// match output layout to input layout // match output layout to input layout
if (!q_channel_first) y = y.contiguous(); if (!q_channel_first) y = y.contiguous();
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
return y; return y;
} }
...@@ -33,6 +33,6 @@ ...@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{ {
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2"); m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)"); m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
} }
...@@ -37,10 +37,10 @@ ...@@ -37,10 +37,10 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) \ #define CHECK_CUDA_INPUT_TENSOR(x) \
CHECK_CUDA_TENSOR(x); \ CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b)) #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define MIN_THREADS (64) #define MIN_THREADS (64)
#define ELXTH_MAX (32) #define ELXTH_MAX (32)
......
...@@ -38,122 +38,122 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H ...@@ -38,122 +38,122 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out) const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{ {
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // gloabl row const int64_t bidx = blockIdx.x; // gloabl row
const int64_t bidy = blockIdx.y; // bc const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx]; int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1]; int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff]; const int64_t ker = kers[soff];
const int64_t row = rows[soff]; const int64_t row = rows[soff];
inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi; inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
out += bidy * Ho * Wo; out += bidy * Ho * Wo;
// align to larger supported fp type // align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale] extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr); REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);
// copy current inp row in regs // copy current inp row in regs
REAL_T __reg[ELXTH]; REAL_T __reg[ELXTH];
#pragma unroll #pragma unroll
for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); } for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }
// reset shared row up to Wo+2, remaining // reset shared row up to Wo+2, remaining
// ppscale*(BDIM_X*ELXTH - Wo) locations // ppscale*(BDIM_X*ELXTH - Wo) locations
// will be written to but never copied to // will be written to but never copied to
// global mem // global mem
for (int i = 0; i < pscale; i++) { for (int i = 0; i < pscale; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; } for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
} }
__syncthreads(); __syncthreads();
int col_prev = cols[soff]; int col_prev = cols[soff];
int h_prev = col_prev / Wo; int h_prev = col_prev / Wo;
int w_prev = col_prev % Wo; int w_prev = col_prev % Wo;
// loops along the colums of CTA's row // loops along the colums of CTA's row
for (int64_t nz = soff; nz < eoff; nz++) { for (int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz]; const int col = cols[nz];
const REAL_T val = vals[nz]; const REAL_T val = vals[nz];
// if we are processing a nz with a col value // if we are processing a nz with a col value
// leading to a new row of inp then copy it // leading to a new row of inp then copy it
// to shmem; // to shmem;
// we read a col that points to a new output // we read a col that points to a new output
// row if (col / Wo) > (col_prev / Wo) // row if (col / Wo) > (col_prev / Wo)
if (col >= col_prev - w_prev + Wo) { if (col >= col_prev - w_prev + Wo) {
__syncthreads(); __syncthreads();
for (int i = 0; i < pscale; i++) { for (int i = 0; i < pscale; i++) {
for (int j = tid; j < Wi; j += BDIM_X) { for (int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j]; const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev * Wo + j * pscale + i], v); atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
__sh[i][j] = 0; __sh[i][j] = 0;
__sh[i][Wi + j] = 0; __sh[i][Wi + j] = 0;
} }
} }
__syncthreads(); __syncthreads();
col_prev = col; col_prev = col;
h_prev = col / Wo; h_prev = col / Wo;
w_prev = col % Wo; w_prev = col % Wo;
} }
const int w = w_prev + (col - col_prev); const int w = w_prev + (col - col_prev);
const int w_mod_ps = w % pscale; const int w_mod_ps = w % pscale;
const int w_div_ps = w / pscale; const int w_div_ps = w / pscale;
#pragma unroll #pragma unroll
for (int i = 0; i < ELXTH; i++) { for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid; const int pp = i * BDIM_X + tid;
__sh[w_mod_ps][w_div_ps + pp] += val * __reg[i]; __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
} }
// to avoid race conditions on __sh[] // to avoid race conditions on __sh[]
// among consecutive iterations along nz // among consecutive iterations along nz
__syncthreads();
}
__syncthreads(); __syncthreads();
}
__syncthreads();
// write last row // write last row
for (int i = 0; i < pscale; i++) { for (int i = 0; i < pscale; i++) {
for (int j = tid; j < Wi; j += BDIM_X) { for (int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j]; const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev * Wo + j * pscale + i], v); atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
}
} }
} return;
return;
} }
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T> template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__ __global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo, __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff, const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows, const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals, const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out) const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{ {
if constexpr (PSCALE != 0) { if constexpr (PSCALE != 0) {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out); disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
} else { } else {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out); disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
} }
return; return;
} }
template <int NTH, int ELXTH, typename REAL_T> template <int NTH, int ELXTH, typename REAL_T>
...@@ -162,113 +162,118 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t ...@@ -162,113 +162,118 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
cudaStream_t stream) cudaStream_t stream)
{ {
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8); static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
if constexpr (ELXTH <= ELXTH_MAX) { if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wi) { if (NTH * ELXTH >= Wi) {
dim3 grid(nrows, BC); dim3 grid(nrows, BC);
const int pscale = Wo / Wi; const int pscale = Wo / Wi;
size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale); size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);
switch (pscale) { switch (pscale) {
case 1: case 1:
disco_bwd_blk_k<NTH, ELXTH, 1> disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
break; break;
case 2: case 2:
disco_bwd_blk_k<NTH, ELXTH, 2> disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
break; break;
case 3: case 3:
disco_bwd_blk_k<NTH, ELXTH, 3> disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
break; break;
default: default:
disco_bwd_blk_k<NTH, ELXTH, 0> disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); row_d, col_d, val_d, inp_d, out_d);
} }
} else { } else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d, launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
stream); out_d, stream);
}
} }
} return;
return;
} }
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{ {
// some sanity checks // some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp); CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx); CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx); CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val); CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes // extract some shapes
int64_t B = inp.size(0); int64_t B = inp.size(0);
int64_t C = inp.size(1); int64_t C = inp.size(1);
int64_t BC = B * C; int64_t BC = B * C;
int64_t Hi = inp.size(3); int64_t Hi = inp.size(3);
int64_t Wi = inp.size(4); int64_t Wi = inp.size(4);
int64_t nrows = roff_idx.size(0) - 1; int64_t nrows = roff_idx.size(0) - 1;
// allocate output // allocate output
int64_t out_dims[] = {B, C, Ho, Wo}; int64_t out_dims[] = {B, C, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options); torch::Tensor out = torch::zeros(out_dims, options);
// get stream // get stream
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert // assert
static_assert(0 == (ELXTH_MAX % 2)); static_assert(0 == (ELXTH_MAX % 2));
if (Wo <= 64 * ELXTH_MAX) { if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>( launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
})); inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
} else if (Wo <= 128 * ELXTH_MAX) { }));
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { } else if (Wo <= 128 * ELXTH_MAX) {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
})); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
} else if (Wo <= 256 * ELXTH_MAX) { inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { }));
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( } else if (Wo <= 256 * ELXTH_MAX) {
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
})); ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
} else if (Wo <= 512 * ELXTH_MAX) { col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( }));
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), } else if (Wo <= 512 * ELXTH_MAX) {
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
})); BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
} else if (Wo <= 1024 * ELXTH_MAX) { ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] { col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), }));
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), } else if (Wo <= 1024 * ELXTH_MAX) {
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
})); launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
} else { BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
1024 * ELXTH_MAX); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
exit(EXIT_FAILURE); inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
} }));
} else {
return out; fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
} }
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
...@@ -38,123 +38,124 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H ...@@ -38,123 +38,124 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out) const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{ {
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // gloabl row const int64_t bidx = blockIdx.x; // gloabl row
const int64_t bidy = blockIdx.y; // bc const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx]; int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1]; int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff]; const int64_t ker = kers[soff];
const int64_t row = rows[soff]; const int64_t row = rows[soff];
inp += bidy * Hi * Wi; inp += bidy * Hi * Wi;
out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo; out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;
REAL_T __reg[ELXTH] = {0}; REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type // align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] extern __shared__ __align__(
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr); sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
int col_prev = cols[soff]; int col_prev = cols[soff];
int h_prev = col_prev / Wi; int h_prev = col_prev / Wi;
int w_prev = col_prev % Wi; int w_prev = col_prev % Wi;
// copy current inp row in shmem // copy current inp row in shmem
for (int i = tid; i < Wi; i += BDIM_X) { for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i]; const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v; __sh[i] = v;
__sh[Wi + i] = v; __sh[Wi + i] = v;
} }
// locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used // locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
__syncthreads(); __syncthreads();
// loops along the colums of CTA's row // loops along the colums of CTA's row
for (int64_t nz = soff; nz < eoff; nz++) { for (int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz]; const int col = cols[nz];
const REAL_T val = vals[nz]; const REAL_T val = vals[nz];
// if we are processing a nz with a col value // if we are processing a nz with a col value
// leading to a new row of inp then copy it // leading to a new row of inp then copy it
// to shmem; // to shmem;
// checks whether (h_prev < h) with: // checks whether (h_prev < h) with:
// (col >= col_prev - (col_prev % Wi) + Wi) // (col >= col_prev - (col_prev % Wi) + Wi)
if (col >= col_prev - w_prev + Wi) { if (col >= col_prev - w_prev + Wi) {
col_prev = col; col_prev = col;
h_prev = col / Wi; h_prev = col / Wi;
w_prev = col % Wi; w_prev = col % Wi;
__syncthreads(); __syncthreads();
for (int i = tid; i < Wi; i += BDIM_X) { for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i]; const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v; __sh[i] = v;
__sh[Wi + i] = v; __sh[Wi + i] = v;
} }
__syncthreads(); __syncthreads();
} }
const int w = w_prev + (col - col_prev); const int w = w_prev + (col - col_prev);
#pragma unroll #pragma unroll
for (int i = 0; i < ELXTH; i++) { for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid; const int pp = i * BDIM_X + tid;
// original lines: // original lines:
// //
// if (pp >= Wo) break; // if (pp >= Wo) break;
// const int wpp = (w + pscale*pp) % Wi; // const int wpp = (w + pscale*pp) % Wi;
// //
// value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi // value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
// so we can allocate twice the amount of shmem, // so we can allocate twice the amount of shmem,
// replicate the current inp row and avoid the costly mod // replicate the current inp row and avoid the costly mod
// //
// also, to avoid the conditional, sh can be extended to // also, to avoid the conditional, sh can be extended to
// cover the maximum location accessed during this loop // cover the maximum location accessed during this loop
// //
// REAL_T __sh[2*Wi + ppscale*NUM_REM] // REAL_T __sh[2*Wi + ppscale*NUM_REM]
// //
// Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) = // Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
// = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) = // = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
// = 2*Wi + ppscale*NUM_REM // = 2*Wi + ppscale*NUM_REM
// //
// with NUM_REM = BDIM_X*ELXTH - Wo // with NUM_REM = BDIM_X*ELXTH - Wo
const int wpp = w + pscale * pp; const int wpp = w + pscale * pp;
__reg[i] += val * __sh[wpp]; __reg[i] += val * __sh[wpp];
}
} }
}
#pragma unroll #pragma unroll
for (int i = 0; i < ELXTH; i++) { for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid; const int pp = i * BDIM_X + tid;
if (pp >= Wo) break; if (pp >= Wo) break;
out[pp] = __reg[i]; out[pp] = __reg[i];
} }
return; return;
} }
template <int BDIM_X, int ELXTH, typename REAL_T> template <int BDIM_X, int ELXTH, typename REAL_T>
__global__ __global__
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo, __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff, const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows, const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals, const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out) const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{ {
disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out); disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
return; return;
} }
template <int NTH, int ELXTH, typename REAL_T> template <int NTH, int ELXTH, typename REAL_T>
...@@ -163,97 +164,102 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t ...@@ -163,97 +164,102 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
cudaStream_t stream) cudaStream_t stream)
{ {
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8); static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
if constexpr (ELXTH <= ELXTH_MAX) { if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wo) { if (NTH * ELXTH >= Wo) {
dim3 grid(nrows, BC); dim3 grid(nrows, BC);
const int pscale = Wi / Wo; const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo)); size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
disco_fwd_blk_k<NTH, ELXTH> disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d); col_d, val_d, inp_d, out_d);
} else { } else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d, launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
stream); out_d, stream);
}
} }
} return;
return;
} }
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx, torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo) torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{ {
// some sanity checks // some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp); CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx); CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx); CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx); CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx); CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val); CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes // extract some shapes
int64_t B = inp.size(0); int64_t B = inp.size(0);
int64_t C = inp.size(1); int64_t C = inp.size(1);
int64_t BC = B * C; int64_t BC = B * C;
int64_t Hi = inp.size(2); int64_t Hi = inp.size(2);
int64_t Wi = inp.size(3); int64_t Wi = inp.size(3);
int64_t nrows = roff_idx.size(0) - 1; int64_t nrows = roff_idx.size(0) - 1;
// allocate output // allocate output
int64_t out_dims[] = {B, C, K, Ho, Wo}; int64_t out_dims[] = {B, C, K, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype()); auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options); torch::Tensor out = torch::zeros(out_dims, options);
// get stream // get stream
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert // assert
static_assert(0 == (ELXTH_MAX % 2)); static_assert(0 == (ELXTH_MAX % 2));
// pick the correct launch config // pick the correct launch config
if (Wo <= 64 * ELXTH_MAX) { if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>( launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
})); inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
} else if (Wo <= 128 * ELXTH_MAX) { }));
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { } else if (Wo <= 128 * ELXTH_MAX) {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>( AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
})); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
} else if (Wo <= 256 * ELXTH_MAX) { inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { }));
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>( } else if (Wo <= 256 * ELXTH_MAX) {
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
})); ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
} else if (Wo <= 512 * ELXTH_MAX) { col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>( }));
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), } else if (Wo <= 512 * ELXTH_MAX) {
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
})); BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
} else if (Wo <= 1024 * ELXTH_MAX) { ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] { col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>( inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(), }));
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(), } else if (Wo <= 1024 * ELXTH_MAX) {
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream); AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
})); launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
} else { BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo, ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
1024 * ELXTH_MAX); col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
exit(EXIT_FAILURE); inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
} }));
} else {
return out; fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
} }
...@@ -33,6 +33,6 @@ ...@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{ {
m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)"); m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)"); m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment