Commit 373f9b0b authored by Thorsten Kurth's avatar Thorsten Kurth
Browse files

formatting

parent ebc122eb
......@@ -36,16 +36,11 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
int nlon_out);
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
at::Tensor dy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
......@@ -51,28 +51,32 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
}}
} \
}
#endif
#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer {
class ScopeTimer
{
public:
explicit ScopeTimer(const std::string& label = "")
: label_(label), start_(std::chrono::high_resolution_clock::now()) {}
explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now())
{
}
~ScopeTimer() {
~ScopeTimer()
{
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
......@@ -83,17 +87,16 @@ private:
std::chrono::high_resolution_clock::time_point start_;
};
static __device__ float __warp_sum(float val) {
static __device__ float __warp_sum(float val)
{
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
static __device__ float __warp_sum_cub(float val)
{
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
......@@ -108,14 +111,9 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be layed out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel(
int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
......@@ -125,14 +123,15 @@ __launch_bounds__(BDIM_X)
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
extern __shared__ float sh[];
float* sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float *sh_alpha_vw = sh_alpha_k + num_channels;
float *sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_dy = sh_alpha_kvw + num_channels;
float* sh_qy = sh_dy + num_channels;
float *sh_qy = sh_dy + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
......@@ -156,7 +155,7 @@ __launch_bounds__(BDIM_X)
__syncthreads();
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg;
// First pass: find qdotk_max
......@@ -166,9 +165,7 @@ __launch_bounds__(BDIM_X)
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
}
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; }
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
}
......@@ -201,7 +198,8 @@ __launch_bounds__(BDIM_X)
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][chan][ho][wo] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
dydq[batchId][chan][ho][wo]
= (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
// Third pass: accumulate gradients for k and v
......@@ -227,16 +225,11 @@ __launch_bounds__(BDIM_X)
}
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
at::Tensor dy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
......@@ -257,7 +250,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
// auto* permute_timer = new ScopeTimer("permute inputs");
//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
......@@ -300,8 +293,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
cudaEvent_t start, stop;
......@@ -310,10 +303,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
......@@ -339,9 +330,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
if(!k_channel_first) dydk = dydk.contiguous();
if(!v_channel_first) dydv = dydv.contiguous();
if(!q_channel_first) dydq = dydq.contiguous();
if (!k_channel_first) dydk = dydk.contiguous();
if (!v_channel_first) dydv = dydv.contiguous();
if (!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
......@@ -352,6 +343,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq);
}
......@@ -45,36 +45,39 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)
#define CHECK_CUDA(call) { \
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
}}
} \
}
#define CHECK_ERROR(errorMessage) { \
#define CHECK_ERROR(errorMessage) \
{ \
cudaError_t err = cudaGetLastError(); \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) ); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
}}
} \
}
static __device__ float __warp_sum(float val) {
static __device__ float __warp_sum(float val)
{
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
static __device__ float __warp_sum_cub(float val)
{
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
......@@ -85,40 +88,33 @@ static __device__ float __warp_sum_cub(float val) {
return sum;
}
// one warp per (ho,wo)
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_kernel(int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
extern __shared__ float sh[];
float *shy = sh + threadIdx.y*num_channels;
float *shy = sh + threadIdx.y * num_channels;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x)*blockDim.y + threadIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out)*nlon_in) {
return;
}
if (wid >= uint64_t(nlat_out) * nlon_in) { return; }
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho*nlon_out);
const int wo = wid - (ho * nlon_out);
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
......@@ -130,23 +126,22 @@ __launch_bounds__(BDIM_X)
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend-rbeg;
const int rlen = rend - rbeg;
for(int off = 0; off < rlen; off++) {
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg+off];
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi*nlon_in);
const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][ wo]*
kx[batchId][chan][hi][wip];
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
......@@ -158,31 +153,23 @@ __launch_bounds__(BDIM_X)
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum*exp_save;
alpha_sum = alpha + alpha_sum * exp_save;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan]*exp_save + vx[batchId][chan][hi][wip]*alpha;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
}
qdotk_max = qdotk_max_tmp;
}
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
}
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }
return;
}
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in,
int nlat_out,
int nlon_out) {
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
......@@ -206,7 +193,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
......@@ -232,10 +219,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
nvtxRangePop();
torch::Tensor y = torch::empty_like(qy);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float)*uo_num_channels * block.y;
size_t shared_size = sizeof(float) * uo_num_channels * block.y;
cudaEvent_t start, stop;
float milliseconds = 0;
......@@ -243,9 +230,8 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_kernel<THREADS>
<<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
......@@ -267,4 +253,3 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
return y;
}
......@@ -31,8 +31,8 @@
#include "attention.cuh"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
......@@ -36,32 +36,19 @@
#include <c10/cuda/CUDAStream.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) CHECK_CUDA_TENSOR(x); CHECK_CONTIGUOUS_TENSOR(x)
#define CHECK_CUDA_INPUT_TENSOR(x) \
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
// forward kernel
torch::Tensor disco_cuda_fwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo);
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
// backward kernel
torch::Tensor disco_cuda_bwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo);
\ No newline at end of file
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
......@@ -31,23 +31,12 @@
#include "disco.h"
#include "disco_cuda.cuh"
template<int BDIM_X,
int ELXTH,
typename REAL_T>
__device__ void disco_bwd_d(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
template <int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
const int tid = threadIdx.x;
......@@ -55,36 +44,32 @@ __device__ void disco_bwd_d(const int Hi,
const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx+1];
int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
inp += bidy*K*Hi*Wi + ker*Hi*Wi + row*Wi;
out += bidy*Ho*Wo;
inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
out += bidy * Ho * Wo;
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
REAL_T (*__sh)[BDIM_X*ELXTH*2] = reinterpret_cast<REAL_T (*)[BDIM_X*ELXTH*2]>(__sh_ptr);
REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);
// copy current inp row in regs
REAL_T __reg[ELXTH];
#pragma unroll
for(int i = 0; i < ELXTH; i++) {
__reg[i] = (i*BDIM_X+tid < Wi) ? inp[i*BDIM_X +tid] : REAL_T(0);
}
#pragma unroll
for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }
// reset shared row up to Wo+2, remaining
// ppscale*(BDIM_X*ELXTH - Wo) locations
// will be written to but never copied to
// global mem
for(int i = 0; i < pscale; i++) {
#pragma unroll
for(int j = 0; j < 2*BDIM_X*ELXTH; j += BDIM_X) {
__sh[i][j+tid] = 0;
}
for (int i = 0; i < pscale; i++) {
#pragma unroll
for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
}
__syncthreads();
......@@ -94,7 +79,7 @@ __device__ void disco_bwd_d(const int Hi,
int w_prev = col_prev % Wo;
// loops along the colums of CTA's row
for(int64_t nz = soff; nz < eoff; nz++) {
for (int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz];
const REAL_T val = vals[nz];
......@@ -104,16 +89,16 @@ __device__ void disco_bwd_d(const int Hi,
// to shmem;
// we read a col that points to a new output
// row if (col / Wo) > (col_prev / Wo)
if (col >= col_prev-w_prev+Wo) {
if (col >= col_prev - w_prev + Wo) {
__syncthreads();
for(int i = 0; i < pscale; i++) {
for(int j = tid; j < Wi; j += BDIM_X) {
for (int i = 0; i < pscale; i++) {
for (int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev*Wo + j*pscale + i], v);
atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
__sh[i][ j] = 0;
__sh[i][j] = 0;
__sh[i][Wi + j] = 0;
}
}
......@@ -124,15 +109,15 @@ __device__ void disco_bwd_d(const int Hi,
w_prev = col % Wo;
}
const int w = w_prev + (col-col_prev);
const int w = w_prev + (col - col_prev);
const int w_mod_ps = w % pscale;
const int w_div_ps = w / pscale;
#pragma unroll
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i*BDIM_X + tid;
__sh[w_mod_ps][w_div_ps + pp] += val*__reg[i];
const int pp = i * BDIM_X + tid;
__sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
}
// to avoid race conditions on __sh[]
......@@ -142,127 +127,78 @@ __device__ void disco_bwd_d(const int Hi,
__syncthreads();
// write last row
for(int i = 0; i < pscale; i++) {
for (int i = 0; i < pscale; i++) {
for(int j = tid; j < Wi; j += BDIM_X) {
for (int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev*Wo + j*pscale + i], v);
atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
}
}
return;
}
template<int BDIM_X,
int ELXTH,
int PSCALE,
typename REAL_T>
__global__ __launch_bounds__(BDIM_X)
void disco_bwd_blk_k(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
if constexpr(PSCALE != 0) { disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out); }
else { disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out); }
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
if constexpr (PSCALE != 0) {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
} else {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
}
return;
}
template <int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d, int64_t *ker_d,
int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
cudaStream_t stream)
{
template<int NTH,
int ELXTH,
typename REAL_T>
static void launch_kernel(int BC,
int Hi,
int Wi,
int K,
int Ho,
int Wo,
int64_t nrows,
int64_t *roff_d,
int64_t *ker_d,
int64_t *row_d,
int64_t *col_d,
REAL_T *val_d,
REAL_T *inp_d,
REAL_T *out_d,
cudaStream_t stream) {
static_assert(sizeof(REAL_T) == 2 ||
sizeof(REAL_T) == 4 ||
sizeof(REAL_T) == 8);
if constexpr(ELXTH <= ELXTH_MAX) {
if (NTH*ELXTH >= Wi) {
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wi) {
dim3 grid(nrows, BC);
const int pscale = Wo/Wi;
size_t shmem = sizeof(*out_d)*(2 * (NTH*ELXTH)*pscale);
const int pscale = Wo / Wi;
size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);
switch(pscale) {
switch (pscale) {
case 1:
disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 1>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
break;
case 2:
disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 2>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
break;
case 3:
disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 3>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
break;
default:
disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
disco_bwd_blk_k<NTH, ELXTH, 0>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
}
} else {
launch_kernel<NTH, ELXTH+1>(BC,
Hi, Wi,
K, Ho, Wo,
nrows,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d,
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
}
}
return;
}
torch::Tensor disco_cuda_bwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo) {
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
......@@ -289,85 +225,52 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp,
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX%2));
static_assert(0 == (ELXTH_MAX % 2));
if (Wo <= 64*ELXTH_MAX) {
if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 128*ELXTH_MAX) {
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 256*ELXTH_MAX) {
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 512*ELXTH_MAX) {
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 1024*ELXTH_MAX) {
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else {
fprintf(stderr,
"%s:%d: error, unsupported Wo value (%ld), max supported is %d\n",
__FILE__, __LINE__, Wo, 1024*ELXTH_MAX);
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
}
//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
//}
......@@ -31,23 +31,12 @@
#include "disco.h"
#include "disco_cuda.cuh"
template<int BDIM_X,
int ELXTH,
typename REAL_T>
__device__ void disco_fwd_d(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
template <int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
const int tid = threadIdx.x;
......@@ -55,13 +44,13 @@ __device__ void disco_fwd_d(const int Hi,
const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx+1];
int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
inp += bidy*Hi*Wi;
out += bidy*K*Ho*Wo + ker*Ho*Wo + row*Wo;
inp += bidy * Hi * Wi;
out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;
REAL_T __reg[ELXTH] = {0};
......@@ -75,16 +64,16 @@ __device__ void disco_fwd_d(const int Hi,
int w_prev = col_prev % Wi;
// copy current inp row in shmem
for(int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev*Wi + i];
__sh[ i] = v;
for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v;
__sh[Wi + i] = v;
}
// locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
__syncthreads();
// loops along the colums of CTA's row
for(int64_t nz = soff; nz < eoff; nz++) {
for (int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz];
const REAL_T val = vals[nz];
......@@ -94,27 +83,27 @@ __device__ void disco_fwd_d(const int Hi,
// to shmem;
// checks whether (h_prev < h) with:
// (col >= col_prev - (col_prev % Wi) + Wi)
if (col >= col_prev-w_prev+Wi) {
if (col >= col_prev - w_prev + Wi) {
col_prev = col;
h_prev = col / Wi;
w_prev = col % Wi;
__syncthreads();
for(int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev*Wi + i];
__sh[ i] = v;
for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v;
__sh[Wi + i] = v;
}
__syncthreads();
}
const int w = w_prev + (col-col_prev);
const int w = w_prev + (col - col_prev);
#pragma unroll
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i*BDIM_X + tid;
const int pp = i * BDIM_X + tid;
// original lines:
//
......@@ -136,17 +125,16 @@ __device__ void disco_fwd_d(const int Hi,
//
// with NUM_REM = BDIM_X*ELXTH - Wo
const int wpp = w + pscale*pp;
__reg[i] += val*__sh[wpp];
const int wpp = w + pscale * pp;
__reg[i] += val * __sh[wpp];
}
}
#pragma unroll
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i*BDIM_X + tid;
const int pp = i * BDIM_X + tid;
if (pp >= Wo) break;
out[pp] = __reg[i];
......@@ -155,92 +143,48 @@ __device__ void disco_fwd_d(const int Hi,
return;
}
template<int BDIM_X,
int ELXTH,
typename REAL_T>
__global__ __launch_bounds__(BDIM_X)
void disco_fwd_blk_k(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
return;
}
template <int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d, int64_t *ker_d,
int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
cudaStream_t stream)
{
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
template<int NTH,
int ELXTH,
typename REAL_T>
static void launch_kernel(int BC,
int Hi,
int Wi,
int K,
int Ho,
int Wo,
int64_t nrows,
int64_t *roff_d,
int64_t *ker_d,
int64_t *row_d,
int64_t *col_d,
REAL_T *val_d,
REAL_T *inp_d,
REAL_T *out_d,
cudaStream_t stream) {
static_assert(sizeof(REAL_T) == 2 ||
sizeof(REAL_T) == 4 ||
sizeof(REAL_T) == 8);
if constexpr(ELXTH <= ELXTH_MAX) {
if (NTH*ELXTH >= Wo) {
if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wo) {
dim3 grid(nrows, BC);
const int pscale = Wi/Wo;
size_t shmem = sizeof(*out_d)*(Wi*2 + pscale*(NTH*ELXTH-Wo));
const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
disco_fwd_blk_k<NTH, ELXTH>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
} else {
launch_kernel<NTH, ELXTH+1>(BC,
Hi, Wi,
K, Ho, Wo,
nrows,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d,
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
}
}
return;
}
torch::Tensor disco_cuda_fwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo) {
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
......@@ -267,81 +211,49 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp,
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX%2));
static_assert(0 == (ELXTH_MAX % 2));
// pick the correct launch config
if (Wo <= 64*ELXTH_MAX) {
if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 128*ELXTH_MAX) {
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 256*ELXTH_MAX) {
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 512*ELXTH_MAX) {
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else if (Wo <= 1024*ELXTH_MAX) {
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
}
else {
fprintf(stderr,
"%s:%d: error, unsupported Wo value (%ld), max supported is %d\n",
__FILE__, __LINE__, Wo, 1024*ELXTH_MAX);
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
}
......@@ -30,31 +30,21 @@
#include "disco.h"
template<typename REAL_T>
void preprocess_psi_kernel(int64_t nnz,
int64_t K,
int64_t Ho,
int64_t *ker_h,
int64_t *row_h,
int64_t *col_h,
int64_t *roff_h,
REAL_T *val_h,
int64_t& nrows) {
template <typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
int64_t *Koff = new int64_t[K];
for(int i = 0; i < K; i++) {
Koff[i] = 0;
}
for (int i = 0; i < K; i++) { Koff[i] = 0; }
for(int64_t i = 0; i < nnz; i++) {
Koff[ker_h[i]]++;
}
for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }
int64_t prev = Koff[0];
Koff[0] = 0;
for(int i = 1; i < K; i++) {
for (int i = 1; i < K; i++) {
int64_t save = Koff[i];
Koff[i] = prev + Koff[i-1];
Koff[i] = prev + Koff[i - 1];
prev = save;
}
......@@ -63,7 +53,7 @@ void preprocess_psi_kernel(int64_t nnz,
int64_t *col_sort = new int64_t[nnz];
float *val_sort = new float[nnz];
for(int64_t i = 0; i < nnz; i++) {
for (int64_t i = 0; i < nnz; i++) {
const int64_t ker = ker_h[i];
const int64_t off = Koff[ker]++;
......@@ -73,31 +63,30 @@ void preprocess_psi_kernel(int64_t nnz,
col_sort[off] = col_h[i];
val_sort[off] = val_h[i];
}
for(int64_t i = 0; i < nnz; i++) {
for (int64_t i = 0; i < nnz; i++) {
ker_h[i] = ker_sort[i];
row_h[i] = row_sort[i];
col_h[i] = col_sort[i];
val_h[i] = val_sort[i];
}
delete [] Koff;
delete [] ker_sort;
delete [] row_sort;
delete [] col_sort;
delete [] val_sort;
delete[] Koff;
delete[] ker_sort;
delete[] row_sort;
delete[] col_sort;
delete[] val_sort;
// compute rows offsets
nrows = 1;
roff_h[0] = 0;
for(int64_t i = 1; i < nnz; i++) {
for (int64_t i = 1; i < nnz; i++) {
if (row_h[i-1] == row_h[i]) continue;
if (row_h[i - 1] == row_h[i]) continue;
roff_h[nrows++] = i;
if (nrows > Ho*K) {
fprintf(stderr,
"%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n",
__FILE__, __LINE__, int64_t(Ho)*K);
if (nrows > Ho * K) {
fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n", __FILE__, __LINE__,
int64_t(Ho) * K);
exit(EXIT_FAILURE);
}
}
......@@ -106,13 +95,9 @@ void preprocess_psi_kernel(int64_t nnz,
return;
}
torch::Tensor preprocess_psi(const int64_t K,
const int64_t Ho,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val) {
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val)
{
CHECK_INPUT_TENSOR(ker_idx);
CHECK_INPUT_TENSOR(row_idx);
......@@ -123,33 +108,27 @@ torch::Tensor preprocess_psi(const int64_t K,
int64_t *ker_h = ker_idx.data_ptr<int64_t>();
int64_t *row_h = row_idx.data_ptr<int64_t>();
int64_t *col_h = col_idx.data_ptr<int64_t>();
int64_t *roff_h = new int64_t[Ho*K+1];
int64_t *roff_h = new int64_t[Ho * K + 1];
int64_t nrows;
//float *val_h = val.data_ptr<float>();
AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&]{
preprocess_psi_kernel<scalar_t>(nnz, K, Ho,
ker_h,
row_h,
col_h,
roff_h,
val.data_ptr<scalar_t>(),
nrows);
// float *val_h = val.data_ptr<float>();
AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
val.data_ptr<scalar_t>(), nrows);
}));
// create output tensor
auto options = torch::TensorOptions().dtype(row_idx.dtype());
auto roff_idx = torch::empty({nrows+1}, options);
auto roff_idx = torch::empty({nrows + 1}, options);
int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
for(int64_t i = 0; i < (nrows+1); i++) {
roff_out_h[i] = roff_h[i];
}
delete [] roff_h;
for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
delete[] roff_h;
return roff_idx;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
......@@ -31,9 +31,8 @@
#include "disco.h"
#include "disco_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment