Commit c46b6925 authored by Max Rietmann

Applied new formatting

parent 1ea5c4ca
---
BasedOnStyle: Webkit
IndentWidth: 2
IndentWidth: 4
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignTrailingComments: true
@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
BreakConstructorInitializers: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 2
ContinuationIndentWidth: 2
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
FixNamespaceComments: true
NamespaceIndentation: All
@@ -51,17 +51,17 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#endif
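A minimal usage sketch for the CHECK_CUDA macro above; the cudaMalloc/cudaFree calls are only an illustration, not part of this commit:
float *d_buf = nullptr;
CHECK_CUDA(cudaMalloc(&d_buf, 1024 * sizeof(float))); // any CUDA runtime call can be wrapped
CHECK_CUDA(cudaFree(d_buf));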
#include <iostream>
@@ -70,41 +70,42 @@
class ScopeTimer
{
public:
explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now())
{
}
~ScopeTimer()
{
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
}
private:
std::string label_;
std::chrono::high_resolution_clock::time_point start_;
public:
explicit ScopeTimer(const std::string &label = "") :
label_(label), start_(std::chrono::high_resolution_clock::now())
{
}
~ScopeTimer()
{
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
}
private:
std::string label_;
std::chrono::high_resolution_clock::time_point start_;
};
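ScopeTimer above is an RAII helper: it records a start time on construction and prints the elapsed milliseconds from its destructor. A minimal usage sketch (the label text is hypothetical):
{
    ScopeTimer timer("permute inputs ");
    // ... work to be timed ...
}   // prints: "permute inputs Elapsed time: <N> ms" when the scope ends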
static __device__ float __warp_sum(float val)
{
#pragma unroll
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
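Both reductions above return the warp-wide sum in every lane of the 32-lane warp (__warp_sum via an XOR butterfly, __warp_sum_cub via cub::WarpReduce followed by a broadcast from lane 0). In the kernels below each warp owns one output point and its lanes stride over channels, so dot products follow this sketch (the arrays a and b are hypothetical):
float partial = 0.0f;
for (int chan = threadIdx.x; chan < num_channels; chan += WARP_SIZE) { partial += a[chan] * b[chan]; }
float dot = __warp_sum_cub(partial); // identical value in all 32 lanes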
// This kernel computes the backward pass for the S2 attention mechanism, using
@@ -113,107 +114,107 @@ static __device__ float __warp_sum_cub(float val)
// memory access.
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
extern __shared__ float sh[];
float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float *sh_alpha_vw = sh_alpha_k + num_channels;
float *sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_dy = sh_alpha_kvw + num_channels;
float *sh_qy = sh_dy + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) return;
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
// Zero shared memory
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][chan][ho][wo];
sh_qy[chan] = qy[batchId][chan][ho][wo];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
float integral = 0.0f;
__syncthreads();
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg;
// 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
extern __shared__ float sh[];
float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float *sh_alpha_vw = sh_alpha_k + num_channels;
float *sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_dy = sh_alpha_kvw + num_channels;
float *sh_qy = sh_dy + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) return;
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
// Zero shared memory
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][chan][ho][wo];
sh_qy[chan] = qy[batchId][chan][ho][wo];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float qdotk_max_tmp = max(qdotk_max, qdotk);
float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
float max_correction = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha_sum * max_correction + alpha_inz;
integral = integral * max_correction + alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][chan][hi][wip];
sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
float integral = 0.0f;
__syncthreads();
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg;
// 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float qdotk_max_tmp = max(qdotk_max, qdotk);
float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
float max_correction = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha_sum * max_correction + alpha_inz;
integral = integral * max_correction + alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][chan][hi][wip];
sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
}
qdotk_max = qdotk_max_tmp;
}
qdotk_max = qdotk_max_tmp;
}
integral /= alpha_sum;
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][chan][ho][wo]
= (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
// Third pass: accumulate gradients for k and v
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
integral /= alpha_sum;
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
dydq[batchId][chan][ho][wo]
= (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float qyval = qy[batchId][chan][ho][wo];
float dyval = sh_dy[chan];
atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
// Third pass: accumulate gradients for k and v
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float qyval = qy[batchId][chan][ho][wo];
float dyval = sh_dy[chan];
atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
}
}
}
}
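The two passes above implement a streaming (online) softmax: the running maximum m of the logits q·k_j is updated as neighbors are consumed, and every running sum is rescaled by the same correction factor so the final result is as if the maximum had been known up front. With s = alpha_sum, I = integral, w_j = quad_weights[hi], and g·v_j = gdotv, the recurrence is, as a sketch in LaTeX notation:
m' = \max(m,\; q \cdot k_j), \quad \alpha_j = e^{\,q \cdot k_j - m'}\, w_j, \quad s' = s\, e^{\,m - m'} + \alpha_j, \quad I' = I\, e^{\,m - m'} + \alpha_j \,(g \cdot v_j)
The per-channel accumulators sh_alpha_k, sh_alpha_vw and sh_alpha_kvw are rescaled by the same factor e^{m - m'}, and the final division integral = I / s cancels the shift, so the running max only provides numerical stability.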
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
@@ -222,122 +223,122 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
int nlon_in, int nlat_out, int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
auto dy_channel_first = dy.strides()[1] == 1;
// Transpose to [batch, ho, wo, channel]
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
// auto* permute_timer = new ScopeTimer("permute inputs");
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
auto dyP = at::Tensor();
if (!dy_channel_first) {
// printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
dyP = dy;
}
// cudaDeviceSynchronize();
// delete permute_timer;
nvtxRangePop();
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
auto dydk = torch::zeros_like(qyP);
auto dydv = torch::zeros_like(qyP);
auto dydq = torch::zeros_like(qyP);
// print stride of dydkP, dydvP, dydqP
nvtxRangePop();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 per-warp arrays (alpha_k, alpha_vw, alpha_kvw, dy, qy)
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// s2_attention_bwd_kernel execution time: 51.231743 ms
// s2_attention_bwd_kernel execution time: 52.971519 ms
// s2_attention_bwd_kernel execution time: 50.724865 ms
// [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel execution time: 11.679744 ms
printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
if (!k_channel_first) dydk = dydk.contiguous();
if (!v_channel_first) dydv = dydv.contiguous();
if (!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
// printf("%ld,", stride);
// }
// printf("]\n");
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq);
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
auto dy_channel_first = dy.strides()[1] == 1;
// Transpose to [batch, ho, wo, channel]
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
// auto* permute_timer = new ScopeTimer("permute inputs");
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
auto dyP = at::Tensor();
if (!dy_channel_first) {
// printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
dyP = dy;
}
// cudaDeviceSynchronize();
// delete permute_timer;
nvtxRangePop();
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
auto dydk = torch::zeros_like(qyP);
auto dydv = torch::zeros_like(qyP);
auto dydq = torch::zeros_like(qyP);
// print stride of dydkP, dydvP, dydqP
nvtxRangePop();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 per-warp arrays (alpha_k, alpha_vw, alpha_kvw, dy, qy)
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// s2_attention_bwd_kernel execution time: 51.231743 ms
// s2_attention_bwd_kernel execution time: 52.971519 ms
// s2_attention_bwd_kernel execution time: 50.724865 ms
// [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel execution time: 11.679744 ms
printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
if (!k_channel_first) dydk = dydk.contiguous();
if (!v_channel_first) dydv = dydv.contiguous();
if (!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
// printf("%ld,", stride);
// }
// printf("]\n");
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq);
}
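The permute/contiguous/permute idiom above (and in the forward wrapper later in this commit) changes only the memory layout: the tensor keeps its logical [batch, channel, ho, wo] shape while channels become the innermost, stride-1 dimension, which is what the per-warp channel loops need for coalesced loads. A minimal illustration with hypothetical sizes:
auto x  = torch::randn({2, 256, 90, 180}, torch::kCUDA);              // logical [B, C, H, W]
auto xP = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
// xP.sizes()   == [2, 256, 90, 180]            -- shape unchanged
// xP.strides() == [4147200, 1, 46080, 256]     -- channel stride is now 1 (channels-last layout)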
@@ -45,125 +45,125 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define NNZ_TRESH (32)
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define CHECK_ERROR(errorMessage) \
{ \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
{ \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
static __device__ float __warp_sum(float val)
{
#pragma unroll
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
// one warp per (ho,wo)
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
extern __shared__ float sh[];
float *shy = sh + threadIdx.y * num_channels;
extern __shared__ float sh[];
float *shy = sh + threadIdx.y * num_channels;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
const uint64_t batchId = blockIdx.y;
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
if (wid >= uint64_t(nlat_out) * nlon_in) { return; }
if (wid >= uint64_t(nlat_out) * nlon_in) { return; }
const int tidx = threadIdx.x;
const int tidx = threadIdx.x;
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
const int ho = wid / nlon_out;
const int wo = wid - (ho * nlon_out);
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
#else
shy[chan] = 0;
shy[chan] = 0;
#endif
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho + 1];
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg;
const int rlen = rend - rbeg;
for (int off = 0; off < rlen; off++) {
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
float qdotk_max_tmp;
float alpha;
float exp_save;
float qdotk_max_tmp;
float alpha;
float exp_save;
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
qdotk_max_tmp = max(qdotk_max, qdotk);
alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
exp_save = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha + alpha_sum * exp_save;
alpha_sum = alpha + alpha_sum * exp_save;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
}
qdotk_max = qdotk_max_tmp;
}
qdotk_max = qdotk_max_tmp;
}
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }
return;
return;
}
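For each output point (ho, wo) the warp above computes a quadrature-weighted softmax average of the values over the nonzero neighborhood of row ho, using the same running-max correction sketched after the backward kernel. With w_j = quad_weights[hi] and m = \max_j q \cdot k_j, the value written to y is, as a sketch in LaTeX notation:
y[c] = \frac{\sum_j e^{\,q \cdot k_j - m}\, w_j\, v_j[c]}{\sum_j e^{\,q \cdot k_j - m}\, w_j}
The shift by m cancels between numerator and denominator, so it only guards expf against overflow.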
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
@@ -171,85 +171,85 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
cudaDeviceSynchronize();
nvtxRangePop();
torch::Tensor y = torch::empty_like(qy);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * block.y;
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
// match output layout to input layout
if (!q_channel_first) y = y.contiguous();
C10_CUDA_KERNEL_LAUNCH_CHECK();
return y;
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// TODO: check sizes
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
// Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
auto kxP = at::Tensor();
if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
cudaDeviceSynchronize();
nvtxRangePop();
torch::Tensor y = torch::empty_like(qy);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * block.y;
cudaEvent_t start, stop;
float milliseconds = 0;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
CHECK_CUDA(cudaEventRecord(stop, stream));
CHECK_CUDA(cudaEventSynchronize(stop));
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
// match output layout to input layout
if (!q_channel_first) y = y.contiguous();
C10_CUDA_KERNEL_LAUNCH_CHECK();
return y;
}
@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
@@ -37,10 +37,10 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) \
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
@@ -38,122 +38,122 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
const int tid = threadIdx.x;
const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // global row
const int64_t bidy = blockIdx.y; // bc
const int64_t bidx = blockIdx.x; // global row
const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1];
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
out += bidy * Ho * Wo;
inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
out += bidy * Ho * Wo;
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);
REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);
// copy current inp row in regs
REAL_T __reg[ELXTH];
// copy current inp row in regs
REAL_T __reg[ELXTH];
#pragma unroll
for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }
for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }
// reset shared row up to Wo+2, remaining
// ppscale*(BDIM_X*ELXTH - Wo) locations
// will be written to but never copied to
// global mem
for (int i = 0; i < pscale; i++) {
// reset shared row up to Wo+2, remaining
// ppscale*(BDIM_X*ELXTH - Wo) locations
// will be written to but never copied to
// global mem
for (int i = 0; i < pscale; i++) {
#pragma unroll
for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
}
__syncthreads();
for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
}
__syncthreads();
int col_prev = cols[soff];
int col_prev = cols[soff];
int h_prev = col_prev / Wo;
int w_prev = col_prev % Wo;
int h_prev = col_prev / Wo;
int w_prev = col_prev % Wo;
// loops along the columns of the CTA's row
for (int64_t nz = soff; nz < eoff; nz++) {
// loops along the columns of the CTA's row
for (int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz];
const REAL_T val = vals[nz];
const int col = cols[nz];
const REAL_T val = vals[nz];
// if we are processing a nz with a col value
// leading to a new row of inp then copy it
// to shmem;
// we read a col that points to a new output
// row if (col / Wo) > (col_prev / Wo)
if (col >= col_prev - w_prev + Wo) {
__syncthreads();
for (int i = 0; i < pscale; i++) {
for (int j = tid; j < Wi; j += BDIM_X) {
// if we are processing a nz with a col value
// leading to a new row of inp then copy it
// to shmem;
// we read a col that points to a new output
// row if (col / Wo) > (col_prev / Wo)
if (col >= col_prev - w_prev + Wo) {
__syncthreads();
for (int i = 0; i < pscale; i++) {
for (int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
__sh[i][j] = 0;
__sh[i][Wi + j] = 0;
}
}
__syncthreads();
__sh[i][j] = 0;
__sh[i][Wi + j] = 0;
}
}
__syncthreads();
col_prev = col;
h_prev = col / Wo;
w_prev = col % Wo;
}
col_prev = col;
h_prev = col / Wo;
w_prev = col % Wo;
}
const int w = w_prev + (col - col_prev);
const int w_mod_ps = w % pscale;
const int w_div_ps = w / pscale;
const int w = w_prev + (col - col_prev);
const int w_mod_ps = w % pscale;
const int w_div_ps = w / pscale;
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid;
__sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
}
const int pp = i * BDIM_X + tid;
__sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
}
// to avoid race conditions on __sh[]
// among consecutive iterations along nz
// to avoid race conditions on __sh[]
// among consecutive iterations along nz
__syncthreads();
}
__syncthreads();
}
__syncthreads();
// write last row
for (int i = 0; i < pscale; i++) {
// write last row
for (int i = 0; i < pscale; i++) {
for (int j = tid; j < Wi; j += BDIM_X) {
for (int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
}
}
}
return;
return;
}
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
if constexpr (PSCALE != 0) {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
} else {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
}
if constexpr (PSCALE != 0) {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
} else {
disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
}
return;
return;
}
template <int NTH, int ELXTH, typename REAL_T>
@@ -162,113 +162,118 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
cudaStream_t stream)
{
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wi) {
dim3 grid(nrows, BC);
const int pscale = Wo / Wi;
size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);
switch (pscale) {
case 1:
disco_bwd_blk_k<NTH, ELXTH, 1>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
break;
case 2:
disco_bwd_blk_k<NTH, ELXTH, 2>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
break;
case 3:
disco_bwd_blk_k<NTH, ELXTH, 3>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
break;
default:
disco_bwd_blk_k<NTH, ELXTH, 0>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
}
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wi) {
dim3 grid(nrows, BC);
const int pscale = Wo / Wi;
size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);
switch (pscale) {
case 1:
disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
case 2:
disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
case 3:
disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
break;
default:
disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
row_d, col_d, val_d, inp_d, out_d);
}
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
out_d, stream);
}
}
}
return;
return;
}
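launch_kernel above searches, at compile time, for the smallest ELXTH such that NTH * ELXTH covers the row width (Wi here, Wo in the forward variant), recursing on the template parameter and bounded by ELXTH_MAX so only a fixed set of kernel instantiations is generated. A stripped-down sketch of the pattern (names are hypothetical):
template <int NTH, int ELXTH>
void pick_elxth(int width)
{
    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= width) {
            // launch the kernel instantiated as <NTH, ELXTH>
        } else {
            pick_elxth<NTH, ELXTH + 1>(width); // try the next per-thread element count
        }
    }
}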
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes
int64_t B = inp.size(0);
int64_t C = inp.size(1);
int64_t BC = B * C;
int64_t Hi = inp.size(3);
int64_t Wi = inp.size(4);
int64_t nrows = roff_idx.size(0) - 1;
// allocate output
int64_t out_dims[] = {B, C, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX % 2));
if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes
int64_t B = inp.size(0);
int64_t C = inp.size(1);
int64_t BC = B * C;
int64_t Hi = inp.size(3);
int64_t Wi = inp.size(4);
int64_t nrows = roff_idx.size(0) - 1;
// allocate output
int64_t out_dims[] = {B, C, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX % 2));
if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -38,123 +38,124 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
const int tid = threadIdx.x;
const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // global row
const int64_t bidy = blockIdx.y; // bc
const int64_t bidx = blockIdx.x; // global row
const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1];
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx + 1];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
inp += bidy * Hi * Wi;
out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;
inp += bidy * Hi * Wi;
out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;
REAL_T __reg[ELXTH] = {0};
REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
// align to larger supported fp type
extern __shared__ __align__(
sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
int col_prev = cols[soff];
int col_prev = cols[soff];
int h_prev = col_prev / Wi;
int w_prev = col_prev % Wi;
int h_prev = col_prev / Wi;
int w_prev = col_prev % Wi;
// copy current inp row in shmem
for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v;
__sh[Wi + i] = v;
}
// locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
__syncthreads();
// copy current inp row in shmem
for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v;
__sh[Wi + i] = v;
}
// locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
__syncthreads();
// loops along the columns of the CTA's row
for (int64_t nz = soff; nz < eoff; nz++) {
// loops along the columns of the CTA's row
for (int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz];
const REAL_T val = vals[nz];
const int col = cols[nz];
const REAL_T val = vals[nz];
// if we are processing a nz with a col value
// leading to a new row of inp then copy it
// to shmem;
// checks whether (h_prev < h) with:
// (col >= col_prev - (col_prev % Wi) + Wi)
if (col >= col_prev - w_prev + Wi) {
// if we are processing a nz with a col value
// leading to a new row of inp then copy it
// to shmem;
// checks whether (h_prev < h) with:
// (col >= col_prev - (col_prev % Wi) + Wi)
if (col >= col_prev - w_prev + Wi) {
col_prev = col;
h_prev = col / Wi;
w_prev = col % Wi;
col_prev = col;
h_prev = col / Wi;
w_prev = col % Wi;
__syncthreads();
for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v;
__sh[Wi + i] = v;
}
__syncthreads();
}
__syncthreads();
for (int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev * Wi + i];
__sh[i] = v;
__sh[Wi + i] = v;
}
__syncthreads();
}
const int w = w_prev + (col - col_prev);
const int w = w_prev + (col - col_prev);
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid;
// original lines:
//
// if (pp >= Wo) break;
// const int wpp = (w + pscale*pp) % Wi;
//
// value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
// so we can allocate twice the amount of shmem,
// replicate the current inp row and avoid the costly mod
//
// also, to avoid the conditional, sh can be extended to
// cover the maximum location accessed during this loop
//
// REAL_T __sh[2*Wi + ppscale*NUM_REM]
//
// Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
// = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
// = 2*Wi + ppscale*NUM_REM
//
// with NUM_REM = BDIM_X*ELXTH - Wo
const int wpp = w + pscale * pp;
__reg[i] += val * __sh[wpp];
for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid;
// original lines:
//
// if (pp >= Wo) break;
// const int wpp = (w + pscale*pp) % Wi;
//
// value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
// so we can allocate twice the amount of shmem,
// replicate the current inp row and avoid the costly mod
//
// also, to avoid the conditional, sh can be extended to
// cover the maximum location accessed during this loop
//
// REAL_T __sh[2*Wi + ppscale*NUM_REM]
//
// Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
// = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
// = 2*Wi + ppscale*NUM_REM
//
// with NUM_REM = BDIM_X*ELXTH - Wo
const int wpp = w + pscale * pp;
__reg[i] += val * __sh[wpp];
}
}
}
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
for (int i = 0; i < ELXTH; i++) {
const int pp = i * BDIM_X + tid;
if (pp >= Wo) break;
const int pp = i * BDIM_X + tid;
if (pp >= Wo) break;
out[pp] = __reg[i];
}
out[pp] = __reg[i];
}
return;
return;
}
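A concrete instance of the shared-memory sizing argument in the long comment above (numbers are hypothetical but consistent with the dispatch further below): take Wi = 1440 and Wo = 720, so pscale = Wi / Wo = 2. The dispatcher picks the smallest ELXTH with 64 * ELXTH >= Wo, i.e. BDIM_X * ELXTH = 64 * 12 = 768, giving NUM_REM = 768 - 720 = 48. The buffer then holds 2*Wi + pscale*NUM_REM = 2880 + 96 = 2976 elements, while the largest index read is at most (Wi - 1) + pscale * (BDIM_X * ELXTH - 1) = 1439 + 1534 = 2973, so the unconditional __sh[wpp] access stays in bounds.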
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
return;
return;
}
template <int NTH, int ELXTH, typename REAL_T>
@@ -163,97 +164,102 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
cudaStream_t stream)
{
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);
if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wo) {
dim3 grid(nrows, BC);
if constexpr (ELXTH <= ELXTH_MAX) {
if (NTH * ELXTH >= Wo) {
dim3 grid(nrows, BC);
const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
const int pscale = Wi / Wo;
size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));
disco_fwd_blk_k<NTH, ELXTH>
<<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
stream);
disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
col_d, val_d, inp_d, out_d);
} else {
launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
out_d, stream);
}
}
}
return;
return;
}
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes
int64_t B = inp.size(0);
int64_t C = inp.size(1);
int64_t BC = B * C;
int64_t Hi = inp.size(2);
int64_t Wi = inp.size(3);
int64_t nrows = roff_idx.size(0) - 1;
// allocate output
int64_t out_dims[] = {B, C, K, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX % 2));
// pick the correct launch config
if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes
int64_t B = inp.size(0);
int64_t C = inp.size(1);
int64_t BC = B * C;
int64_t Hi = inp.size(2);
int64_t Wi = inp.size(3);
int64_t nrows = roff_idx.size(0) - 1;
// allocate output
int64_t out_dims[] = {B, C, K, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX % 2));
// pick the correct launch config
if (Wo <= 64 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 128 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 256 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 512 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else if (Wo <= 1024 * ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
}));
} else {
fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
1024 * ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
}
@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}