Commit 373f9b0b authored by Thorsten Kurth

formatting

parent ebc122eb
@@ -36,16 +36,11 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")

torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out);

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out);
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -51,28 +51,32 @@
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call)                                                                                              \
    {                                                                                                                 \
        cudaError_t err = call;                                                                                       \
        if (cudaSuccess != err) {                                                                                     \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE);                                                                                       \
        }                                                                                                             \
    }
#endif

#include <iostream>
#include <chrono>
#include <string>

class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now())
    {
    }

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
@@ -83,20 +87,19 @@ private:
    std::chrono::high_resolution_clock::time_point start_;
};

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }

    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);

    // 2. Broadcast sum to all threads
@@ -108,31 +111,27 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
    float *sh_alpha_vw = sh_alpha_k + num_channels;
    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
    float *sh_dy = sh_alpha_kvw + num_channels;
    float *sh_qy = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
@@ -156,7 +155,7 @@
    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // First pass: find qdotk_max
@@ -166,9 +165,7 @@
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; }
        qdotk = __warp_sum_cub(qdotk);
        qdotk_max = max(qdotk_max, qdotk);
    }
@@ -201,7 +198,8 @@
    // Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo]
            = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
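Aside (not part of the commit): the dydq expression above is the quotient rule applied to the normalized attention output. Assuming, per channel c, that the shared buffers hold A_k[c] = sum_i alpha_i k_i[c] and A_kvw[c] = sum_i alpha_i k_i[c] <v_i, dy>, and that sh_alpha_vw carries the sum A_vw = sum_i alpha_i <v_i, dy>, with alpha_i = omega_i exp(q . k_i) (an assumption here, since the accumulation loop lies outside this hunk), then for y = (sum_i alpha_i v_i) / (sum_i alpha_i) the gradient with respect to q is

\[
\frac{\partial L}{\partial q_c}
  = \frac{A_{kvw}[c]\,\sum_j \alpha_j \;-\; A_{vw}\,A_k[c]}{\big(\sum_j \alpha_j\big)^2},
\]

which is the (kvw * alpha_sum - vw * k) / alpha_sum^2 pattern written back above.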
@@ -227,16 +225,11 @@ __launch_bounds__(BDIM_X)
    }
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out)
{

    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -257,7 +250,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
@@ -300,8 +293,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp

    cudaEvent_t start, stop;
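Aside (not part of the commit): a small standalone sketch of the launch arithmetic above; the grid sizes and channel count are made-up values chosen only for illustration.

#include <cstdio>

int main()
{
    const int WARP_SIZE = 32, THREADS = 64;
    const int nlat_out = 361, nlon_out = 720, num_channels = 256, batch_size = 4; // assumed sizes
    const int block_x = WARP_SIZE, block_y = THREADS / WARP_SIZE;             // (32, 2): two warps per block
    const int grid_x = (nlat_out * nlon_out + block_y - 1) / block_y;         // one warp per output point (ho, wo)
    const size_t shared = sizeof(float) * num_channels * 5 * block_y;         // 5 channel-length buffers per warp
    printf("block=(%d,%d) grid=(%d,%d) shared=%zu bytes\n", block_x, block_y, grid_x, batch_size, shared);
    return 0;
}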
@@ -310,20 +303,18 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
@@ -333,15 +324,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to memory layout given by input. if input had channels
    // first, leave it in that layout, otherwise permute layout back to [batch,
    // channel, ho, wo]
    if (!k_channel_first) dydk = dydk.contiguous();
    if (!v_channel_first) dydv = dydv.contiguous();
    if (!q_channel_first) dydq = dydq.contiguous();

    // printf("dydk strides:[");
    // for(auto& stride : dydk.strides()) {

@@ -352,6 +343,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    // delete permute_output_timer;
    // nvtxRangePop();

    return std::make_tuple(dydk, dydv, dydq);
}
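Aside (not part of the commit): the "permute in memory but keep the original shape" idea mentioned in the comments above can be sketched as below; the helper name is hypothetical and the permute/contiguous combination is an assumption about how such a layout change could be done, not necessarily how torch-harmonics implements it.

#include <torch/extension.h>

// Hypothetical helper: returns a tensor with logical shape [B, C, H, W] whose
// memory is laid out channels-last, so that a kernel indexing [b][c][h][w]
// through a packed accessor reads consecutive addresses along the channel dim.
static at::Tensor to_channels_last_layout(const at::Tensor &x)
{
    // [B, C, H, W] -> [B, H, W, C] in memory, then view it back as [B, C, H, W].
    return x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
}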
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -45,39 +45,42 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)

#define CHECK_CUDA(call)                                                                                              \
    {                                                                                                                 \
        cudaError_t err = call;                                                                                       \
        if (cudaSuccess != err) {                                                                                     \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE);                                                                                       \
        }                                                                                                             \
    }

#define CHECK_ERROR(errorMessage)                                                                                     \
    {                                                                                                                 \
        cudaError_t err = cudaGetLastError();                                                                         \
        if (cudaSuccess != err) {                                                                                     \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__,       \
                    cudaGetErrorString(err));                                                                         \
            exit(EXIT_FAILURE);                                                                                       \
        }                                                                                                             \
    }

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }

    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);

    // 2. Broadcast sum to all threads
@@ -85,40 +88,33 @@ static __device__ float __warp_sum_cub(float val) {
    return sum;
}

// one warp per (ho,wo)
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *shy = sh + threadIdx.y * num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
        // useless read, y is always zeroed before kernel is called
        shy[chan] = y[batchId][chan][ho][wo];

@@ -130,23 +126,22 @@ __launch_bounds__(BDIM_X)
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);

@@ -158,31 +153,23 @@ __launch_bounds__(BDIM_X)
        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha + alpha_sum * exp_save;

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
        }

        qdotk_max = qdotk_max_tmp;
    }

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }

    return;
}
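Aside (a restatement, not part of the commit): the alpha / exp_save bookkeeping above is the standard streaming ("online") softmax rescaling. With running maximum $m$, running normalizer $s$ and running numerator $u$ kept in shy, a neighbor with score $z = q \cdot k$ and quadrature weight $\omega$ is absorbed as

\[
m' = \max(m, z), \qquad c = e^{m - m'}, \qquad \alpha = \omega\, e^{z - m'},
\]
\[
s' = \alpha + c\, s, \qquad u' = c\, u + \alpha\, v,
\]

and after the last neighbor the warp writes $y = u / s$. In the code $c$ is exp_save, $s$ is alpha_sum, $m'$ is qdotk_max_tmp, and $u$ is the per-warp buffer shy.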
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out)
{

    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -206,7 +193,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
@@ -232,10 +219,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
@@ -243,15 +230,14 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));

@@ -267,4 +253,3 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    return y;
}
@@ -31,8 +31,8 @@
#include "attention.cuh"
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
...
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -36,32 +36,19 @@
#include <c10/cuda/CUDAStream.h>

#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))

#define MIN_THREADS (64)
#define ELXTH_MAX (32)

// forward kernel
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);

// backward kernel
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
\ No newline at end of file
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -31,239 +31,175 @@
#include "disco.h"
#include "disco_cuda.cuh"

template <int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                            const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;

    const int64_t bidx = blockIdx.x; // global row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
    out += bidy * Ho * Wo;

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]

    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);

    // copy current inp row in regs
    REAL_T __reg[ELXTH];

#pragma unroll
    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }

    // reset shared row up to Wo+2, remaining
    // ppscale*(BDIM_X*ELXTH - Wo) locations
    // will be written to but never copied to
    // global mem
    for (int i = 0; i < pscale; i++) {
#pragma unroll
        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
    }
    __syncthreads();

    int col_prev = cols[soff];
    int h_prev = col_prev / Wo;
    int w_prev = col_prev % Wo;

    // loops along the columns of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // we read a col that points to a new output
        // row if (col / Wo) > (col_prev / Wo)
        if (col >= col_prev - w_prev + Wo) {

            __syncthreads();

            for (int i = 0; i < pscale; i++) {
                for (int j = tid; j < Wi; j += BDIM_X) {

                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];

                    atomicAdd(&out[h_prev * Wo + j * pscale + i], v);

                    __sh[i][j] = 0;
                    __sh[i][Wi + j] = 0;
                }
            }
            __syncthreads();

            col_prev = col;
            h_prev = col / Wo;
            w_prev = col % Wo;
        }

        const int w = w_prev + (col - col_prev);

        const int w_mod_ps = w % pscale;
        const int w_div_ps = w / pscale;

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {

            const int pp = i * BDIM_X + tid;

            __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
        }

        // to avoid race conditions on __sh[]
        // among consecutive iterations along nz
        __syncthreads();
    }
    __syncthreads();

    // write last row
    for (int i = 0; i < pscale; i++) {
        for (int j = tid; j < Wi; j += BDIM_X) {

            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];

            atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
        }
    }

    return;
}

template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                               const int pscale, const int64_t *__restrict__ roff,
                                               const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                               const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                               const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    if constexpr (PSCALE != 0) {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
    } else {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    }

    return;
}
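Aside (not part of the commit): the PSCALE != 0 branch above is the common trick of specializing a function for a few small compile-time constants while keeping a generic runtime fallback; a minimal standalone sketch of the pattern, with made-up names, is shown below.

#include <cstdio>

// Generic worker: 'FACTOR' is folded into a constant when nonzero,
// otherwise the runtime value is used.
template <int FACTOR>
static void scale_row(const float *in, float *out, int n, int runtime_factor)
{
    const int f = (FACTOR != 0) ? FACTOR : runtime_factor;
    for (int i = 0; i < n; i++) { out[i] = in[i] * f; }
}

// Runtime dispatch to the specialized instantiations, with a generic default.
static void scale_row_dispatch(const float *in, float *out, int n, int factor)
{
    switch (factor) {
    case 1: scale_row<1>(in, out, n, factor); break;
    case 2: scale_row<2>(in, out, n, factor); break;
    case 3: scale_row<3>(in, out, n, factor); break;
    default: scale_row<0>(in, out, n, factor); break;
    }
}

int main()
{
    float in[4] = {1, 2, 3, 4}, out[4];
    scale_row_dispatch(in, out, 4, 2);
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
    return 0;
}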

template <int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d, int64_t *ker_d,
                          int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wi) {
            dim3 grid(nrows, BC);

            const int pscale = Wo / Wi;
            size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);

            switch (pscale) {
            case 1:
                disco_bwd_blk_k<NTH, ELXTH, 1>
                    <<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
                break;
            case 2:
                disco_bwd_blk_k<NTH, ELXTH, 2>
                    <<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
                break;
            case 3:
                disco_bwd_blk_k<NTH, ELXTH, 3>
                    <<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
                break;
            default:
                disco_bwd_blk_k<NTH, ELXTH, 0>
                    <<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
            }
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
                                          stream);
        }
    }

    return;
}
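Aside (not part of the commit): the ELXTH recursion above effectively searches, at compile time, for the smallest elements-per-thread count such that NTH * ELXTH covers the row width; a condensed illustration of the idiom with hypothetical names:

#include <cstdio>

constexpr int ELXTH_MAX_DEMO = 8;

// Recursively instantiate versions with increasing elements-per-thread until
// NTH * ELXTH covers 'width'; the recursion is bounded by ELXTH_MAX_DEMO.
template <int NTH, int ELXTH>
static void pick_elxth(int width)
{
    if constexpr (ELXTH <= ELXTH_MAX_DEMO) {
        if (NTH * ELXTH >= width) {
            printf("would launch with NTH=%d, ELXTH=%d for width=%d\n", NTH, ELXTH, width);
        } else {
            pick_elxth<NTH, ELXTH + 1>(width);
        }
    }
}

int main()
{
    pick_elxth<64, 1>(300); // selects ELXTH = 5, since 64 * 5 = 320 >= 300
    return 0;
}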

torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
@@ -287,87 +223,54 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp,
    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
//}
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -31,101 +31,90 @@
#include "disco.h"
#include "disco_cuda.cuh"

template <int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                            const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;

    const int64_t bidx = blockIdx.x; // global row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * Hi * Wi;
    out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;

    REAL_T __reg[ELXTH] = {0};

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]

    REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);

    int col_prev = cols[soff];
    int h_prev = col_prev / Wi;
    int w_prev = col_prev % Wi;

    // copy current inp row in shmem
    for (int i = tid; i < Wi; i += BDIM_X) {
        const REAL_T v = inp[h_prev * Wi + i];
        __sh[i] = v;
        __sh[Wi + i] = v;
    }
    // locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
    __syncthreads();

    // loops along the columns of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // checks whether (h_prev < h) with:
        // (col >= col_prev - (col_prev % Wi) + Wi)
        if (col >= col_prev - w_prev + Wi) {

            col_prev = col;
            h_prev = col / Wi;
            w_prev = col % Wi;

            __syncthreads();
            for (int i = tid; i < Wi; i += BDIM_X) {
                const REAL_T v = inp[h_prev * Wi + i];
                __sh[i] = v;
                __sh[Wi + i] = v;
            }
            __syncthreads();
        }

        const int w = w_prev + (col - col_prev);

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {

            const int pp = i * BDIM_X + tid;

            // original lines:
            //
            // if (pp >= Wo) break;
            // const int wpp = (w + pscale*pp) % Wi;
            //
            // value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
            // so we can allocate twice the amount of shmem,
            // replicate the current inp row and avoid the costly mod
            //
            // also, to avoid the conditional, sh can be extended to
            // cover the maximum location accessed during this loop
            //
            // REAL_T __sh[2*Wi + ppscale*NUM_REM]

@@ -135,113 +124,68 @@ __device__ void disco_fwd_d(const int Hi,
            //     = 2*Wi + ppscale*NUM_REM
            //
            // with NUM_REM = BDIM_X*ELXTH - Wo
            const int wpp = w + pscale * pp;
            __reg[i] += val * __sh[wpp];
        }
    }

#pragma unroll
    for (int i = 0; i < ELXTH; i++) {

        const int pp = i * BDIM_X + tid;

        if (pp >= Wo) break;

        out[pp] = __reg[i];
    }

    return;
}
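Aside (not part of the commit): a small numerical check of the row-replication comment above, with assumed sizes Wi = 8 and Wo = 4 (so pscale = Wi / Wo = 2). Every index w + pscale*pp stays below 2*Wi, so reading from a row stored twice back-to-back matches the modulo-based indexing it replaces.

#include <assert.h>
#include <stdio.h>

int main(void)
{
    const int Wi = 8, Wo = 4, pscale = Wi / Wo; // assumed sizes, for illustration only
    float row[8], rep[16];
    for (int i = 0; i < Wi; i++) { row[i] = (float)(i + 1); rep[i] = rep[Wi + i] = row[i]; }

    // w < Wi and pscale*pp < pscale*Wo = Wi, so w + pscale*pp < 2*Wi.
    for (int w = 0; w < Wi; w++) {
        for (int pp = 0; pp < Wo; pp++) {
            assert(rep[w + pscale * pp] == row[(w + pscale * pp) % Wi]);
        }
    }
    printf("replicated-row indexing matches modulo indexing\n");
    return 0;
}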

template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                               const int pscale, const int64_t *__restrict__ roff,
                                               const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                               const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                               const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);

    return;
}

template <int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d, int64_t *ker_d,
                          int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wo) {
            dim3 grid(nrows, BC);

            const int pscale = Wi / Wo;
            size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));

            disco_fwd_blk_k<NTH, ELXTH>
                <<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
                                          stream);
        }
    }

    return;
}

torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
...@@ -265,83 +209,51 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, ...@@ -265,83 +209,51 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp,
    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    // pick the correct launch config
    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
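// Illustrative sketch (not part of this file): the branch chain above simply picks the smallest thread-block size
// (64..1024) whose NTH * ELXTH_MAX coverage is at least Wo, and launch_kernel then refines ELXTH from there. A
// standalone helper expressing the same tier selection, with kElxthMax as a stand-in for ELXTH_MAX, would be:

namespace sketch {

// returns the block size the dispatcher above would use, or -1 if Wo exceeds 1024 * kElxthMax
inline int pick_block_size(long Wo)
{
    const long kElxthMax = 8; // stand-in value, for illustration only
    for (int nth = 64; nth <= 1024; nth *= 2) {
        if (Wo <= nth * kElxthMax) return nth;
    }
    return -1;
}

} // namespace sketch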
...@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
...@@ -30,40 +30,30 @@
#include "disco.h"

template <typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
                           int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
    int64_t *Koff = new int64_t[K];

    for (int i = 0; i < K; i++) { Koff[i] = 0; }

    for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }

    int64_t prev = Koff[0];
    Koff[0] = 0;
    for (int i = 1; i < K; i++) {
        int64_t save = Koff[i];
        Koff[i] = prev + Koff[i - 1];
        prev = save;
    }

    int64_t *ker_sort = new int64_t[nnz];
    int64_t *row_sort = new int64_t[nnz];
    int64_t *col_sort = new int64_t[nnz];
    REAL_T *val_sort = new REAL_T[nnz]; // match REAL_T so double-precision values are not truncated to float
    for (int64_t i = 0; i < nnz; i++) {
        const int64_t ker = ker_h[i];
        const int64_t off = Koff[ker]++;

...@@ -73,31 +63,30 @@ void preprocess_psi_kernel(int64_t nnz,
        col_sort[off] = col_h[i];
        val_sort[off] = val_h[i];
    }

    for (int64_t i = 0; i < nnz; i++) {
        ker_h[i] = ker_sort[i];
        row_h[i] = row_sort[i];
        col_h[i] = col_sort[i];
        val_h[i] = val_sort[i];
    }

    delete[] Koff;
    delete[] ker_sort;
    delete[] row_sort;
    delete[] col_sort;
    delete[] val_sort;
    // compute row offsets
    nrows = 1;
    roff_h[0] = 0;
    for (int64_t i = 1; i < nnz; i++) {
        if (row_h[i - 1] == row_h[i]) continue;

        roff_h[nrows++] = i;
        if (nrows > Ho * K) {
            fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n", __FILE__, __LINE__,
                    int64_t(Ho) * K);
            exit(EXIT_FAILURE);
        }
    }

...@@ -106,50 +95,40 @@ void preprocess_psi_kernel(int64_t nnz,
    return;
}
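// Illustrative sketch (not part of this file): preprocess_psi_kernel groups the COO entries by kernel index with a
// counting sort and then records, in roff_h, the index at which each new row starts (a CSR-style offset array).
// The toy example below replays only the row-offset part on hypothetical data; the closing offset equal to nnz is
// an assumption added here so that consecutive offsets delimit exactly one row.

#include <cstdint>
#include <cstdio>
#include <vector>

namespace sketch {

inline void demo_row_offsets()
{
    // toy row indices, already grouped as they would be after the counting sort
    const std::vector<int64_t> row = {0, 0, 1, 1, 1, 3, 3};

    std::vector<int64_t> roff = {0}; // the first row always starts at offset 0
    for (size_t i = 1; i < row.size(); i++) {
        if (row[i - 1] != row[i]) roff.push_back(static_cast<int64_t>(i)); // a new row starts here
    }
    roff.push_back(static_cast<int64_t>(row.size())); // assumed closing offset

    for (int64_t o : roff) printf("%ld ", (long)o); // prints: 0 2 5 7
    printf("\n");
}

} // namespace sketch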
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
    CHECK_INPUT_TENSOR(ker_idx);
    CHECK_INPUT_TENSOR(row_idx);
    CHECK_INPUT_TENSOR(col_idx);
    CHECK_INPUT_TENSOR(val);

    int64_t nnz = val.size(0);

    int64_t *ker_h = ker_idx.data_ptr<int64_t>();
    int64_t *row_h = row_idx.data_ptr<int64_t>();
    int64_t *col_h = col_idx.data_ptr<int64_t>();
    int64_t *roff_h = new int64_t[Ho * K + 1];

    int64_t nrows;
    // float *val_h = val.data_ptr<float>();

    AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
                                   preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
                                                                   val.data_ptr<scalar_t>(), nrows);
                               }));

    // create output tensor
    auto options = torch::TensorOptions().dtype(row_idx.dtype());
    auto roff_idx = torch::empty({nrows + 1}, options);
    int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
    for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
    delete[] roff_h;

    return roff_idx;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
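// Illustrative sketch (not part of this file): a typical call sequence pairs preprocess_psi with the CUDA forward
// entry point -- the COO tensors are sorted in place on the host, the returned row offsets become roff_idx, and the
// index/value tensors are moved to the GPU before calling disco_cuda_fwd. The wrapper name and the declaration
// below are assumptions for illustration only.

torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);

static torch::Tensor disco_fwd_with_preprocess(torch::Tensor inp, torch::Tensor ker, torch::Tensor row,
                                               torch::Tensor col, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // sort the COO entries in place (on the CPU) and build the CSR-style row offsets
    auto roff = preprocess_psi(K, Ho, ker, row, col, val);

    // the forward kernel expects its tensors on the GPU; inp is assumed to already live there
    return disco_cuda_fwd(inp, roff.cuda(), ker.cuda(), row.cuda(), col.cuda(), val.cuda(), K, Ho, Wo);
}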
...@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
...@@ -31,9 +31,8 @@
#include "disco.h"
#include "disco_cuda.cuh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}