Commit 68e7d0fa authored by Max Rietmann

Clang format

parent cb79c766
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -51,28 +51,32 @@
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call)                                                                                                \
    {                                                                                                                   \
        cudaError_t err = call;                                                                                         \
        if (cudaSuccess != err) {                                                                                       \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE);                                                                                         \
        }                                                                                                               \
    }
#endif

#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now())
    {
    }

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
@@ -83,20 +87,19 @@ private:
    std::chrono::high_resolution_clock::time_point start_;
};
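// Usage sketch (illustrative, not part of this commit): construct a ScopeTimer at the top of
// a scope, e.g. ScopeTimer timer("permute inputs "); the destructor prints the elapsed
// milliseconds when the scope ends. The commented-out timers further down use the same class
// via new/delete.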
static __device__ float __warp_sum(float val)
{
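    // Butterfly reduction: each step exchanges partial sums with the lane whose id differs in
    // one bit (offsets WARP_SIZE/2, ..., 1), so after the loop every lane holds the full warp sum.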
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }

    return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);

    // 2. Broadcast sum to all threads
@@ -108,31 +111,27 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
    float *sh_alpha_vw = sh_alpha_k + num_channels;
    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
    float *sh_dy = sh_alpha_kvw + num_channels;
    float *sh_qy = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)
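    // Each warp (row threadIdx.y of the block) gets its own contiguous slice of five
    // channel-length scratch arrays, so warps within a block never alias each other's accumulators.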
    const uint64_t batchId = blockIdx.y;
@@ -156,7 +155,7 @@ __launch_bounds__(BDIM_X)
    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
@@ -176,7 +175,7 @@ __launch_bounds__(BDIM_X)
        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        float max_correction = expf(qdotk_max - qdotk_max_tmp);
        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv;
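        // Streaming (online-softmax style) rescaling: whenever a larger running maximum
        // qdotk_max_tmp is found, the previously accumulated alpha_sum and integral are
        // multiplied by expf(qdotk_max - qdotk_max_tmp) so they stay expressed relative to
        // the new maximum; expf is never evaluated on an unshifted qdotk.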
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float kxval = kx[batchId][chan][hi][wip];
            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
@@ -190,7 +189,8 @@ __launch_bounds__(BDIM_X)
    // Write dydq
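    // The expression below has the quotient-rule shape (a * D - b * c) / D^2 with D = alpha_sum,
    // i.e. the derivative of the quadrature-weighted, softmax-normalized attention output.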
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo]
            = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
@@ -216,16 +216,11 @@ __launch_bounds__(BDIM_X)
    }
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -246,7 +241,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");

    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
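    // (Such a layout-only change is typically done along the lines of
    //  x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2}), which reorders the data in
    //  memory while keeping the logical shape; the exact call sits outside this hunk.)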
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
@@ -289,8 +284,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);
    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 per-warp channel arrays (sh_alpha_k, sh_alpha_vw, sh_alpha_kvw, sh_dy, sh_qy)
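    // With THREADS = 64 and WARP_SIZE = 32, block = (32, 2): two warps per block, one output
    // point (ho, wo) per warp, grid.x = DIV_UP(nlat_out * nlon_out, block.y) chunks of output
    // points, and grid.y iterating over the batch.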
    cudaEvent_t start, stop;
@@ -299,20 +294,18 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));
    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
@@ -328,15 +321,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // Permute outputs back to the memory layout given by the input. If the input had channels
    // first, leave it in that layout, otherwise permute the layout back to [batch,
    // channel, ho, wo]
    if (!k_channel_first) dydk = dydk.contiguous();
    if (!v_channel_first) dydv = dydv.contiguous();
    if (!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:["); // printf("dydk strides:[");
// for(auto& stride : dydk.strides()) { // for(auto& stride : dydk.strides()) {
...@@ -347,6 +340,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens ...@@ -347,6 +340,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// delete permute_output_timer; // delete permute_output_timer;
// nvtxRangePop(); // nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq); return std::make_tuple(dydk, dydv, dydq);
} }
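
For readers unfamiliar with the access pattern the kernel comment describes (one warp per output point, lanes strided across a channels-fastest layout), the standalone sketch below illustrates it on a toy reduction. It is not part of this commit; the kernel name, buffer names, and sizes are invented for the example, and the butterfly reduction mirrors __warp_sum above.

// Toy CUDA program (assumption-laden sketch, not from torch-harmonics): one warp per output
// point, lanes strided over channels so consecutive lanes read consecutive addresses.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define FULL_MASK 0xffffffff

__global__ void warp_per_point_sum(const float *__restrict__ in, float *__restrict__ out,
                                   int num_points, int num_channels)
{
    const int lane = threadIdx.x;                            // position within the warp
    const int point = blockIdx.x * blockDim.y + threadIdx.y; // one warp per output point
    if (point >= num_points) return;

    // Channels are the fastest dimension: lane i touches elements i, i + 32, ... of this
    // point's row, so a warp's loads coalesce into contiguous 128-byte transactions.
    float acc = 0.f;
    for (int chan = lane; chan < num_channels; chan += WARP_SIZE) acc += in[point * num_channels + chan];

    // Butterfly reduction, same pattern as __warp_sum above.
    for (int i = WARP_SIZE / 2; i; i /= 2) acc += __shfl_xor_sync(FULL_MASK, acc, i);

    if (lane == 0) out[point] = acc;
}

int main()
{
    const int num_points = 4, num_channels = 96;
    float *in, *out;
    cudaMallocManaged(&in, sizeof(float) * num_points * num_channels);
    cudaMallocManaged(&out, sizeof(float) * num_points);
    for (int i = 0; i < num_points * num_channels; ++i) in[i] = 1.f;

    dim3 block(WARP_SIZE, 2);                        // two warps per block, as in the launch above
    dim3 grid((num_points + block.y - 1) / block.y); // DIV_UP over output points
    warp_per_point_sum<<<grid, block>>>(in, out, num_points, num_channels);
    cudaDeviceSynchronize();

    for (int p = 0; p < num_points; ++p) printf("out[%d] = %.1f\n", p, out[p]); // expect 96.0
    cudaFree(in);
    cudaFree(out);
    return 0;
}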