Commit c46b6925 authored by Max Rietmann

Applied new formatting

parent 1ea5c4ca
---
 BasedOnStyle: Webkit
-IndentWidth: 2
+IndentWidth: 4
 AccessModifierOffset: -2
 AlignAfterOpenBracket: Align
 AlignTrailingComments: true
@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
 BreakConstructorInitializers: AfterColon
 ColumnLimit: 120
 ConstructorInitializerAllOnOneLineOrOnePerLine: true
-ConstructorInitializerIndentWidth: 2
-ContinuationIndentWidth: 2
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
 Cpp11BracedListStyle: true
 FixNamespaceComments: true
 NamespaceIndentation: All
@@ -51,17 +51,17 @@
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
    { \
        cudaError_t err = call; \
        if (cudaSuccess != err) { \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE); \
        } \
    }
#endif

#include <iostream>
@@ -70,41 +70,42 @@
class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string &label = "") :
        label_(label), start_(std::chrono::high_resolution_clock::now())
    {
    }

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
    }

  private:
    std::string label_;
    std::chrono::high_resolution_clock::time_point start_;
};

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
    return sum;
}
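// Illustrative sketch (not part of the original file): a minimal standalone check of the
// butterfly-shuffle reduction used by __warp_sum above. Each of the 32 lanes contributes its
// lane id, so after the xor-shuffle passes every lane holds 0 + 1 + ... + 31 = 496. Assumes the
// WARP_SIZE and FULL_MASK macros from this file; the kernel name is hypothetical.
__global__ void warp_sum_demo(float *out)
{
    float val = float(threadIdx.x); // lane id as payload
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { // strides 16, 8, 4, 2, 1
        val += __shfl_xor_sync(FULL_MASK, val, i);
    }
    out[threadIdx.x] = val; // every lane now holds the warp total (496)
}
// launch sketch: warp_sum_demo<<<1, WARP_SIZE>>>(d_out);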
// This kernel computes the backward pass for the S2 attention mechanism, using
@@ -113,107 +114,107 @@ static __device__ float __warp_sum_cub(float val)
// memory access.
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
    float *sh_alpha_vw = sh_alpha_k + num_channels;
    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
    float *sh_dy = sh_alpha_kvw + num_channels;
    float *sh_qy = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) return;

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    // Zero shared memory
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        sh_alpha_k[chan] = 0.0f;
        sh_alpha_vw[chan] = 0.0f;
        sh_alpha_kvw[chan] = 0.0f;
        sh_dy[chan] = dy[batchId][chan][ho][wo];
        sh_qy[chan] = qy[batchId][chan][ho][wo];
    }

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;
    float integral = 0.0f;
    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);

        float qdotk_max_tmp = max(qdotk_max, qdotk);
        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        float max_correction = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv;

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float kxval = kx[batchId][chan][hi][wip];
            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
            sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
        }
        qdotk_max = qdotk_max_tmp;
    }

    integral /= alpha_sum;

    // Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo]
            = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);

        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float qyval = qy[batchId][chan][ho][wo];
            float dyval = sh_dy[chan];
            atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
            atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
        }
    }
}
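// Illustrative sketch (not part of the original file): the max-correction used in the first pass
// above is the standard online-softmax rescaling. When a new score s raises the running max from
// m to m' = max(m, s), the sums accumulated against exp(. - m) are multiplied by exp(m - m'), so
// the final alpha_sum equals sum_j exp(score_j - max_j score_j) * w_j without ever exponentiating
// a large value. Minimal host-side version of that recurrence (hypothetical helper):
static float online_weighted_expsum(const float *scores, const float *weights, int n)
{
    float m = -FLT_MAX;     // running maximum, like qdotk_max
    float alpha_sum = 0.0f; // running normalizer
    for (int j = 0; j < n; j++) {
        const float m_new = fmaxf(m, scores[j]);
        const float alpha = expf(scores[j] - m_new) * weights[j]; // like alpha_inz
        alpha_sum = alpha_sum * expf(m - m_new) + alpha;          // like max_correction
        m = m_new;
    }
    return alpha_sum;
}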
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
@@ -222,122 +223,122 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
                                                                          int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto k_channel_first = kx.strides()[1] == 1;
    auto v_channel_first = vx.strides()[1] == 1;
    auto q_channel_first = qy.strides()[1] == 1;
    auto dy_channel_first = dy.strides()[1] == 1;

    // Transpose to [batch, ho, wo, channel]
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");
    // Permute kx, vx, qy, dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        kxP = kx;
    }
    auto vxP = at::Tensor();
    if (!v_channel_first) {
        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        vxP = vx;
    }
    auto qyP = at::Tensor();
    if (!q_channel_first) {
        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        qyP = qy;
    }
    auto dyP = at::Tensor();
    if (!dy_channel_first) {
        // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        dyP = dy;
    }
    // cudaDeviceSynchronize();
    // delete permute_timer;
    nvtxRangePop();

    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
    auto dydk = torch::zeros_like(qyP);
    auto dydv = torch::zeros_like(qyP);
    auto dydq = torch::zeros_like(qyP);
    // print stride of dydk, dydv, dydq
    nvtxRangePop();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
    //   s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
    //   s2_attention_bwd_kernel execution time: 51.231743 ms
    //   s2_attention_bwd_kernel execution time: 52.971519 ms
    //   s2_attention_bwd_kernel execution time: 50.724865 ms
    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
    //   s2_attention_bwd_kernel execution time: 11.679744 ms
    printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to the memory layout given by the input. If the input had channels
    // first, leave the output in that layout; otherwise permute the layout back to
    // [batch, channel, ho, wo].
    if (!k_channel_first) dydk = dydk.contiguous();
    if (!v_channel_first) dydv = dydv.contiguous();
    if (!q_channel_first) dydq = dydq.contiguous();

    // printf("dydk strides:[");
    // for(auto& stride : dydk.strides()) {
    //     printf("%ld,", stride);
    // }
    // printf("]\n");
    // cudaDeviceSynchronize();
    // delete permute_output_timer;
    // nvtxRangePop();

    return std::make_tuple(dydk, dydv, dydq);
}
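// Illustrative sketch (not part of the original file): the double permute used above keeps the
// logical shape [batch, channel, ho, wo] while making the channel dimension contiguous in memory,
// which is exactly what the kernels detect via strides()[1] == 1. A hypothetical helper under
// that assumption:
static at::Tensor to_channels_last_layout(const at::Tensor &x) // x: [batch, channel, ho, wo]
{
    if (x.strides()[1] == 1) return x; // channels already contiguous in memory
    // reorder the storage to [batch, ho, wo, channel], then view it back in the original dim order
    return x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
}
// Afterwards y.sizes() == x.sizes() but y.strides()[1] == 1, so the warp-wide loops over the
// channel index read consecutive floats (coalesced accesses).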
@@ -45,125 +45,125 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define NNZ_TRESH (32)

#define CHECK_CUDA(call) \
    { \
        cudaError_t err = call; \
        if (cudaSuccess != err) { \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE); \
        } \
    }

#define CHECK_ERROR(errorMessage) \
    { \
        cudaError_t err = cudaGetLastError(); \
        if (cudaSuccess != err) { \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
                    cudaGetErrorString(err)); \
            exit(EXIT_FAILURE); \
        } \
    }

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
    return sum;
}
// one warp per (ho,wo)
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *shy = sh + threadIdx.y * num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
        // useless read, y is always zeroed before kernel is called
        shy[chan] = y[batchId][chan][ho][wo];
#else
        shy[chan] = 0;
#endif
    }

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);

        float qdotk_max_tmp;
        float alpha;
        float exp_save;
        qdotk_max_tmp = max(qdotk_max, qdotk);
        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha + alpha_sum * exp_save;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
        }
        qdotk_max = qdotk_max_tmp;
    }

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }

    return;
}
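// Illustrative sketch (not part of the original file): per output point (ho, wo), the kernel
// above computes a quadrature-weighted softmax average of the neighboring values,
//   y[:, ho, wo] = sum_nz exp(q.k_nz - m) * w_nz * v[:, nz] / sum_nz exp(q.k_nz - m) * w_nz,
// with the max-shift m applied on the fly for numerical stability. A plain, unoptimized CPU
// reference for a single output point, using hypothetical row-major helper arrays:
static void s2_attention_point_reference(int nnz, int num_channels,
                                         const float *q, // [num_channels], query at (ho, wo)
                                         const float *k, // [nnz * num_channels]
                                         const float *v, // [nnz * num_channels]
                                         const float *w, // [nnz] quadrature weights
                                         float *y)       // [num_channels] output
{
    float m = -FLT_MAX;
    for (int nz = 0; nz < nnz; nz++) {
        float qdotk = 0.0f;
        for (int c = 0; c < num_channels; c++) qdotk += q[c] * k[nz * num_channels + c];
        m = fmaxf(m, qdotk);
    }
    float alpha_sum = 0.0f;
    for (int c = 0; c < num_channels; c++) y[c] = 0.0f;
    for (int nz = 0; nz < nnz; nz++) {
        float qdotk = 0.0f;
        for (int c = 0; c < num_channels; c++) qdotk += q[c] * k[nz * num_channels + c];
        const float alpha = expf(qdotk - m) * w[nz];
        alpha_sum += alpha;
        for (int c = 0; c < num_channels; c++) y[c] += alpha * v[nz * num_channels + c];
    }
    for (int c = 0; c < num_channels; c++) y[c] /= alpha_sum;
}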
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
@@ -171,85 +171,85 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                    int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);

    // TODO: check sizes

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    auto k_channel_first = kx.strides()[1] == 1;
    auto v_channel_first = vx.strides()[1] == 1;
    auto q_channel_first = qy.strides()[1] == 1;

    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
    // Permute kx, vx, qy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        kxP = kx;
    }
    auto vxP = at::Tensor();
    if (!v_channel_first) {
        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        vxP = vx;
    }
    auto qyP = at::Tensor();
    if (!q_channel_first) {
        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        qyP = qy;
    }
    cudaDeviceSynchronize();
    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    // match output layout to input layout
    if (!q_channel_first) y = y.contiguous();

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return y;
}
@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
@@ -37,10 +37,10 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x); \
    CHECK_CONTIGUOUS_TENSOR(x)

#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))

#define MIN_THREADS (64)
#define ELXTH_MAX (32)
@@ -38,122 +38,122 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;
    const int64_t bidx = blockIdx.x; // global row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
    out += bidy * Ho * Wo;

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);

    // copy current inp row in regs
    REAL_T __reg[ELXTH];
#pragma unroll
    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }

    // reset shared row up to Wo+2, remaining
    // ppscale*(BDIM_X*ELXTH - Wo) locations
    // will be written to but never copied to
    // global mem
    for (int i = 0; i < pscale; i++) {
#pragma unroll
        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
    }
    __syncthreads();

    int col_prev = cols[soff];
    int h_prev = col_prev / Wo;
    int w_prev = col_prev % Wo;

    // loops along the columns of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {
        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // we read a col that points to a new output
        // row if (col / Wo) > (col_prev / Wo)
        if (col >= col_prev - w_prev + Wo) {
            __syncthreads();
            for (int i = 0; i < pscale; i++) {
                for (int j = tid; j < Wi; j += BDIM_X) {
                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
                    atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
                    __sh[i][j] = 0;
                    __sh[i][Wi + j] = 0;
                }
            }
            __syncthreads();

            col_prev = col;
            h_prev = col / Wo;
            w_prev = col % Wo;
        }

        const int w = w_prev + (col - col_prev);
        const int w_mod_ps = w % pscale;
        const int w_div_ps = w / pscale;

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;
            __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
        }
        // to avoid race conditions on __sh[]
        // among consecutive iterations along nz
        __syncthreads();
    }
    __syncthreads();

    // write last row
    for (int i = 0; i < pscale; i++) {
        for (int j = tid; j < Wi; j += BDIM_X) {
            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
            atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
        }
    }

    return;
}
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                               const int pscale, const int64_t *__restrict__ roff,
                                               const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                               const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                               const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    if constexpr (PSCALE != 0) {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
    } else {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    }
    return;
}
template <int NTH, int ELXTH, typename REAL_T>
@@ -162,113 +162,118 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wi) {
            dim3 grid(nrows, BC);

            const int pscale = Wo / Wi;
            size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);

            switch (pscale) {
            case 1:
                disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
                break;
            case 2:
                disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
                break;
            case 3:
                disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
                break;
            default:
                disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
            }
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
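// Illustrative sketch (not part of the original file): launch_kernel above walks ELXTH = 1, 2, ...
// at compile time (via if constexpr) until NTH * ELXTH covers the row width, so the kernel is
// instantiated with the smallest elements-per-thread count that fits. The same dispatch pattern
// on a trivial kernel, with hypothetical names, assuming the ELXTH_MAX macro from the header:
template <int ELXTH>
__global__ void demo_blk_k(int n, const float *in, float *out)
{
#pragma unroll
    for (int i = 0; i < ELXTH; i++) {
        const int idx = i * blockDim.x + threadIdx.x;
        if (idx < n) out[idx] = 2.0f * in[idx];
    }
}

template <int NTH, int ELXTH>
static void demo_launch(int n, const float *in, float *out, cudaStream_t stream)
{
    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= n) {
            demo_blk_k<ELXTH><<<1, NTH, 0, stream>>>(n, in, out); // smallest ELXTH that covers n
        } else {
            demo_launch<NTH, ELXTH + 1>(n, in, out, stream); // try the next ELXTH
        }
    }
}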
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
    CHECK_CUDA_INPUT_TENSOR(ker_idx);
    CHECK_CUDA_INPUT_TENSOR(row_idx);
    CHECK_CUDA_INPUT_TENSOR(col_idx);
    CHECK_CUDA_INPUT_TENSOR(val);

    // extract some shapes
    int64_t B = inp.size(0);
    int64_t C = inp.size(1);
    int64_t BC = B * C;
    int64_t Hi = inp.size(3);
    int64_t Wi = inp.size(4);
    int64_t nrows = roff_idx.size(0) - 1;

    // allocate output
    int64_t out_dims[] = {B, C, Ho, Wo};
    auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
    torch::Tensor out = torch::zeros(out_dims, options);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...@@ -38,123 +38,124 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H ...@@ -38,123 +38,124 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out) const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{ {
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // gloabl row const int64_t bidx = blockIdx.x; // gloabl row
const int64_t bidy = blockIdx.y; // bc const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx]; int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1]; int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff]; const int64_t ker = kers[soff];
const int64_t row = rows[soff]; const int64_t row = rows[soff];
inp += bidy * Hi * Wi; inp += bidy * Hi * Wi;
out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo; out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;
REAL_T __reg[ELXTH] = {0}; REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type // align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] extern __shared__ __align__(
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
    REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);

    int col_prev = cols[soff];
    int h_prev = col_prev / Wi;
    int w_prev = col_prev % Wi;

    // copy current inp row in shmem
    for (int i = tid; i < Wi; i += BDIM_X) {
        const REAL_T v = inp[h_prev * Wi + i];
        __sh[i] = v;
        __sh[Wi + i] = v;
    }
    // locations __sh[2*Wi : 2*Wi + ppscale*(BDIM_X*ELXTH - Wo)] are not written here
    __syncthreads();

    // loop along the columns of the CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nonzero whose col value
        // leads to a new row of inp, copy that row to shmem;
        // checks whether (h_prev < h) with:
        // (col >= col_prev - (col_prev % Wi) + Wi)
        if (col >= col_prev - w_prev + Wi) {

            col_prev = col;
            h_prev = col / Wi;
            w_prev = col % Wi;

            __syncthreads();
            for (int i = tid; i < Wi; i += BDIM_X) {
                const REAL_T v = inp[h_prev * Wi + i];
                __sh[i] = v;
                __sh[Wi + i] = v;
            }
            __syncthreads();
        }

        const int w = w_prev + (col - col_prev);

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {

            const int pp = i * BDIM_X + tid;

            // original lines:
            //
            // if (pp >= Wo) break;
            // const int wpp = (w + pscale*pp) % Wi;
            //
            // value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
            // so we can allocate twice the amount of shmem,
            // replicate the current inp row and avoid the costly mod
            //
            // also, to avoid the conditional, sh can be extended to
            // cover the maximum location accessed during this loop
            //
            // REAL_T __sh[2*Wi + ppscale*NUM_REM]
            //
            // Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
            // = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
            // = 2*Wi + ppscale*NUM_REM
            //
            // with NUM_REM = BDIM_X*ELXTH - Wo

            const int wpp = w + pscale * pp;
            __reg[i] += val * __sh[wpp];
        }
    }

#pragma unroll
    for (int i = 0; i < ELXTH; i++) {
        const int pp = i * BDIM_X + tid;
        if (pp >= Wo) break;
        out[pp] = __reg[i];
    }
    return;
}
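
// Editorial sketch, not part of the original kernel: a host-side check of the index
// bound that justifies the replicated shared-memory row used above. Because
// col_prev - w_prev == h_prev*Wi, the test (col >= col_prev - w_prev + Wi) is the
// divide-free way of asking whether col falls into a later inp row than h_prev.
// And because w < Wi and pscale = Wi/Wo, every index w + pscale*pp accessed in the
// unrolled loop stays below the 2*Wi + pscale*(BDIM_X*ELXTH - Wo) elements that
// launch_kernel allocates, so the `% Wi` of the original code can be dropped.
// The helper below only re-derives that bound; its name is illustrative.
static bool disco_shmem_bound_holds(int Wi, int Wo, int bdim_x, int elxth)
{
    const int pscale = Wi / Wo;                                    // assumes Wi % Wo == 0
    const int allocated = 2 * Wi + pscale * (bdim_x * elxth - Wo); // elements of __sh
    for (int w = 0; w < Wi; w++) {
        for (int pp = 0; pp < bdim_x * elxth; pp++) {
            if (w + pscale * pp >= allocated) { return false; } // would read past __sh
        }
    }
    return true; // holds whenever Wi % Wo == 0 and bdim_x*elxth >= Wo
}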
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
    __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                                   const int pscale, const int64_t *__restrict__ roff,
                                                   const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                                   const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                                   const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    return;
}
template <int NTH, int ELXTH, typename REAL_T>
@@ -163,97 +164,102 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wo) {
            dim3 grid(nrows, BC);

            const int pscale = Wi / Wo;
            size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));

            disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
                                                                      col_d, val_d, inp_d, out_d);
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
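
// Editorial sketch, not part of the original source: what the if-constexpr recursion
// above resolves to. For a fixed thread count NTH, starting from some ELXTH, it picks
// the first ELXTH <= ELXTH_MAX with NTH*ELXTH >= Wo, i.e. the smallest number of
// per-thread output columns that covers the output row. disco_cuda_fwd below starts
// the 64-thread variant at ELXTH = 1 and the wider variants at ELXTH_MAX/2 + 1, and
// it only picks an NTH for which Wo <= NTH*ELXTH_MAX, so the search always succeeds.
// The helper name is illustrative.
static int disco_select_elxth(int nth, int elxth_start, int elxth_max, int Wo)
{
    for (int elxth = elxth_start; elxth <= elxth_max; elxth++) {
        if (nth * elxth >= Wo) { return elxth; }
    }
    return -1; // unreachable when Wo <= nth * elxth_max
}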
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
    CHECK_CUDA_INPUT_TENSOR(ker_idx);
    CHECK_CUDA_INPUT_TENSOR(row_idx);
    CHECK_CUDA_INPUT_TENSOR(col_idx);
    CHECK_CUDA_INPUT_TENSOR(val);

    // extract some shapes
    int64_t B = inp.size(0);
    int64_t C = inp.size(1);
    int64_t BC = B * C;
    int64_t Hi = inp.size(2);
    int64_t Wi = inp.size(3);
    int64_t nrows = roff_idx.size(0) - 1;

    // allocate output
    int64_t out_dims[] = {B, C, K, Ho, Wo};
    auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
    torch::Tensor out = torch::zeros(out_dims, options);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    // pick the correct launch config
    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
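
// Editorial sketch, not part of the original source: a minimal host-side call of
// disco_cuda_fwd, mainly to document the expected dtypes and shapes. The CSR content
// is a single nonzero (kernel 0, output row 0, input pixel (0, 0), weight 1) and
// assumes the (roff, ker, row, col) convention used by the kernels above; it is a
// smoke test, not a meaningful DISCO filter basis.
static torch::Tensor disco_fwd_smoke_test()
{
    const int64_t B = 1, C = 1, Hi = 4, Wi = 8, K = 1, Ho = 4, Wo = 8;
    auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto i64 = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);

    torch::Tensor inp = torch::randn({B, C, Hi, Wi}, f32);

    torch::Tensor roff = torch::tensor({0, 1}, i64); // nrows = roff.size(0) - 1 = 1
    torch::Tensor ker = torch::tensor({0}, i64);
    torch::Tensor row = torch::tensor({0}, i64);
    torch::Tensor col = torch::tensor({0}, i64);
    torch::Tensor val = torch::tensor({1.0f}, f32);

    // returns a tensor of shape [B, C, K, Ho, Wo]
    return disco_cuda_fwd(inp, roff, ker, row, col, val, K, Ho, Wo);
}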
@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}