Commit 3dd35b45 authored by Max Rietmann

Fixed compile errors for ChannelsLast C++ code; unfortunately, format-on-save also reformatted the file

parent e1338191
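The core of the compile fix visible in the diff below is the ATen memory-format API: the enum value is spelled at::MemoryFormat::ChannelsLast (not Channels_last), and the dtype cast and the layout conversion are issued as two chained .to(...) calls; the single-call form .to(torch::kFloat32, at::MemoryFormat::ChannelsLast) hits the Tensor::to(ScalarType, ...) overload, whose second parameter is a non_blocking flag, which is the likely source of the original compile error. A minimal standalone sketch of the corrected pattern follows; names and shapes are illustrative, not taken from the commit.

// Minimal sketch of the ChannelsLast conversion pattern used in the diff;
// not part of the commit. Shapes and names are illustrative.
#include <torch/torch.h>
#include <iostream>

int main()
{
    // A 4D (N, C, H, W) tensor; ChannelsLast applies to 4D tensors.
    at::Tensor x = torch::randn({2, 8, 16, 32});

    // Cast dtype first, then convert the memory format, mirroring the
    // kxP/vxP/qyP/dyP conversions in the diff.
    at::Tensor xP = x.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);

    // Same layout check the host function performs on its inputs.
    std::cout << std::boolalpha
              << xP.is_contiguous(at::MemoryFormat::ChannelsLast) << std::endl;
    return 0;
}
// --- end of sketch; the commit diff follows ---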
@@ -51,49 +51,53 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
#define CHECK_CUDA(call) \
{ \
cudaError_t err = call; \
if( cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", \
__FILE__, __LINE__, cudaGetErrorString( err) ); \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
}}
} \
}
#endif
#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer {
public:
explicit ScopeTimer(const std::string& label = "")
: label_(label), start_(std::chrono::high_resolution_clock::now()) {}
class ScopeTimer
{
public:
explicit ScopeTimer(const std::string &label = "") :
label_(label), start_(std::chrono::high_resolution_clock::now())
{
}
~ScopeTimer() {
~ScopeTimer()
{
auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
}
private:
private:
std::string label_;
std::chrono::high_resolution_clock::time_point start_;
};
static __device__ float __warp_sum(float val) {
static __device__ float __warp_sum(float val)
{
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
val += __shfl_xor_sync(FULL_MASK, val, i);
}
for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
static __device__ float __warp_sum_cub(float val)
{
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
@@ -108,14 +112,9 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel(
int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
@@ -125,14 +124,15 @@ __launch_bounds__(BDIM_X)
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
extern __shared__ float sh[];
float* sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float *sh_alpha_vw = sh_alpha_k + num_channels;
float *sh_alpha_kvw = sh_alpha_vw + num_channels;
float *sh_dy = sh_alpha_kvw + num_channels;
float* sh_qy = sh_dy + num_channels;
float *sh_qy = sh_dy + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
@@ -156,7 +156,7 @@ __launch_bounds__(BDIM_X)
__syncthreads();
const int64_t rbeg = psi_row_offset[ho];
const int64_t rend = psi_row_offset[ho+1];
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg;
// First pass: find qdotk_max
@@ -201,7 +201,8 @@
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][chan][ho][wo] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
dydq[batchId][chan][ho][wo]
= (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
// Third pass: accumulate gradients for k and v
@@ -227,16 +228,11 @@
}
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
at::Tensor dy, at::Tensor quad_weights,
at::Tensor psi_col_idx, at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out)
{
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
@@ -259,16 +255,16 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto dy_type = dy.dtype();
// extract memory format
auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::Channels_last);
auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::Channels_last);
auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::Channels_last);
auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::Channels_last);
auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
// convert to channels-last
auto kxP = kx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
auto vxP = vx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
auto qyP = qy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
auto dyP = dy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// cudaDeviceSynchronize();
// delete permute_timer;
@@ -284,8 +280,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp
cudaEvent_t start, stop;
@@ -294,10 +290,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -330,20 +324,20 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
dydq = dydq.to(qy_type);
// permute back to original layout
if(!kx_is_channels_last){
dydk = dydk.to(kx_type, at::MemoryFormat::Contiguous);
if (!kx_is_channels_last) {
dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
} else {
dydk = dydk.to(kx_type);
}
if(!vx_is_channels_last){
dydv = dydv.to(vx_type, at::MemoryFormat::Contiguous);
if (!vx_is_channels_last) {
dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
} else {
dydv = dydv.to(vx_type);
}
if(!qy_is_channels_last) {
dydq = dydq.to(qy_type, at::MemoryFormat::Contiguous);
if (!qy_is_channels_last) {
dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
} else {
dydq = dydq.to(qy_type)
dydq = dydq.to(qy_type);
}
// printf("dydk strides: [");
@@ -355,6 +349,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydk, dydv, dydq);
}
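The launch configuration above assigns one warp per output point (ho, wo) and lets the warp's lanes stride over the channel dimension, which is why a channels-last layout yields coalesced loads. A simplified standalone sketch of that indexing scheme follows, under stated assumptions: the exact ho/wo computation sits in lines elided from the diff, so the mapping below, the kernel name, and the raw-pointer arguments are illustrative only, not the library kernel.

// Simplified sketch, not the library kernel: one warp per (ho, wo) output
// point, lanes striding over channels, using the same block/grid math as
// the diff above.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define THREADS 64
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))

__global__ void warp_per_point_kernel(const float *in, float *out, int num_channels, int nlat_out, int nlon_out)
{
    // threadIdx.y selects the warp within the block, threadIdx.x the lane;
    // this point-to-warp mapping is an assumption made for illustration.
    const int point = blockIdx.x * blockDim.y + threadIdx.y;
    if (point >= nlat_out * nlon_out) return;
    const int ho = point / nlon_out;
    const int wo = point % nlon_out;
    const int batch = blockIdx.y;

    // Lanes stride over channels, the fastest-varying dimension in a
    // channels-last layout, so consecutive lanes touch consecutive addresses.
    for (int chan = threadIdx.x; chan < num_channels; chan += WARP_SIZE) {
        const size_t idx = ((size_t)batch * nlat_out * nlon_out + (size_t)ho * nlon_out + wo) * num_channels + chan;
        out[idx] = 2.0f * in[idx]; // placeholder computation
    }
}

int main()
{
    const int batch_size = 1, num_channels = 8, nlat_out = 4, nlon_out = 8;
    const size_t n = (size_t)batch_size * nlat_out * nlon_out * num_channels;

    float *in = nullptr, *out = nullptr;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (size_t i = 0; i < n; i++) in[i] = 1.0f;

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    warp_per_point_kernel<<<grid, block>>>(in, out, num_channels, nlat_out, nlon_out);
    cudaDeviceSynchronize();

    printf("out[0] = %f (expected 2.0)\n", out[0]);
    cudaFree(in);
    cudaFree(out);
    return 0;
}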