Commit e1338191 authored by Thorsten Kurth

using torch tools to change layout in bd pass

parent 49a61eee
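For context, the layout change named in the commit message can be sketched outside the diff as follows. This is an illustrative snippet only (the tensor `x` and its `[batch, channel, lat, lon]` shape are hypothetical, assuming a libtorch C++ build, not code from this commit): instead of the hand-written permute/contiguous round trip used so far, the backward wrapper below queries and restores the layout with torch's memory-format tools.

#include <torch/torch.h>
#include <iostream>

int main() {
    // Hypothetical NCHW-contiguous input.
    auto x = torch::randn({1, 256, 721, 1440});

    // Old style in this file: permute to channels-last storage, keep the NCHW shape.
    auto x_manual = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});

    // New style: query and convert the layout via memory formats.
    bool was_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto x_cl = x.contiguous(at::MemoryFormat::ChannelsLast); // same shape, channel stride 1

    // Both forms expose stride 1 along the channel dimension, which the
    // backward kernel relies on for coalesced access.
    std::cout << x_manual.strides() << " vs " << x_cl.strides() << std::endl;

    // After the kernel, outputs can be handed back in the caller's original layout.
    auto out = was_channels_last ? x_cl : x_cl.contiguous(at::MemoryFormat::Contiguous);
    return 0;
}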
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -51,290 +51,310 @@
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
    cudaError_t err = call; \
    if (cudaSuccess != err) { \
        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
                __FILE__, __LINE__, cudaGetErrorString(err)); \
        exit(EXIT_FAILURE); \
    }}
#endif

#include <iostream>
#include <chrono>
#include <string>

class ScopeTimer {
  public:
    explicit ScopeTimer(const std::string& label = "")
        : label_(label), start_(std::chrono::high_resolution_clock::now()) {}

    ~ScopeTimer() {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
    }

  private:
    std::string label_;
    std::chrono::high_resolution_clock::time_point start_;
};

static __device__ float __warp_sum(float val) {
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) {
        val += __shfl_xor_sync(FULL_MASK, val, i);
    }
    return val;
}

// easier to understand version of the manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

    // 1. compute the sum (initially only valid in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. broadcast the sum to all threads
    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
    return sum;
}
// This kernel computes the backward pass for the S2 attention mechanism, using
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__ __launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel(int num_channels,
                                  int nlon_in,
                                  int nlat_out,
                                  int nlon_out,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
                                  torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
                                  torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
                                  torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
                                  const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                                  const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                                  const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    // per-warp slices of the dynamic shared memory: 5 arrays of num_channels floats
    extern __shared__ float sh[];
    float* sh_alpha_k   = sh + threadIdx.y * num_channels * 5;
    float* sh_alpha_vw  = sh_alpha_k + num_channels;
    float* sh_alpha_kvw = sh_alpha_vw + num_channels;
    float* sh_dy        = sh_alpha_kvw + num_channels;
    float* sh_qy        = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
    if (wid >= uint64_t(nlat_out) * nlon_in) return;

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    // Zero shared memory and cache dy, qy for this output point
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        sh_alpha_k[chan] = 0.0f;
        sh_alpha_vw[chan] = 0.0f;
        sh_alpha_kvw[chan] = 0.0f;
        sh_dy[chan] = dy[batchId][chan][ho][wo];
        sh_qy[chan] = qy[batchId][chan][ho][wo];
    }
    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;
    float integral = 0.0f;
    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // First pass: find qdotk_max
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        qdotk_max = max(qdotk_max, qdotk);
    }

    // Second pass: accumulate alpha_sum, integral, and the shared per-channel stats
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        alpha_sum += alpha_inz;
        integral += alpha_inz * gdotv;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float kxval = kx[batchId][chan][hi][wip];
            sh_alpha_k[chan] += alpha_inz * kxval;
            sh_alpha_vw[chan] += alpha_inz * gdotv;
            sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
        }
    }

    integral /= alpha_sum;

    // Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo] =
            (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float qyval = qy[batchId][chan][ho][wo];
            float dyval = sh_dy[chan];
            atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
            atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
        }
    }
}
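For readers tracing the accumulators above: they correspond to the usual gradients of quadrature-weighted softmax attention. The following is an editor's summary, not text from the commit. Writing $\alpha_i = e^{\,q \cdot k_i - m}\, w_i$ for neighbor $i$ with quadrature weight $w_i$ and row maximum $m$ (the `qdotk_max` shift cancels between numerator and denominator), $A = \sum_i \alpha_i$, and $g$ for the incoming gradient `dy`:

$$
\frac{\partial L}{\partial q} = \frac{\sum_i \alpha_i\,(g \cdot v_i)\,k_i}{A} - \frac{\bigl(\sum_i \alpha_i\,(g \cdot v_i)\bigr)\bigl(\sum_i \alpha_i\,k_i\bigr)}{A^2},
\qquad
\frac{\partial L}{\partial k_i} = \frac{\alpha_i}{A}\Bigl(g \cdot v_i - \frac{\sum_j \alpha_j\,(g \cdot v_j)}{A}\Bigr)\, q,
\qquad
\frac{\partial L}{\partial v_i} = \frac{\alpha_i}{A}\, g.
$$

In the kernel, `sh_alpha_k`, `sh_alpha_vw`, and `sh_alpha_kvw` hold $\sum_i \alpha_i k_i$, $\sum_i \alpha_i (g \cdot v_i)$, and $\sum_i \alpha_i (g \cdot v_i)\, k_i$ per channel, `integral` holds the normalized $\sum_j \alpha_j (g \cdot v_j)/A$, and the two `atomicAdd` calls scatter the $k$ and $v$ gradients back onto the input grid.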
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
                                                                          at::Tensor qy, at::Tensor dy,
                                                                          at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx,
                                                                          at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // Transpose to [batch, ho, wo, channel]
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");

    // extract dtype
    auto kx_type = kx.dtype();
    auto vx_type = vx.dtype();
    auto qy_type = qy.dtype();
    auto dy_type = dy.dtype();

    // extract memory format
    auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);

    // convert to channels-last
    auto kxP = kx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
    auto vxP = vx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
    auto qyP = qy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
    auto dyP = dy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);

    // cudaDeviceSynchronize();
    // delete permute_timer;
    nvtxRangePop();

    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
    auto dydk = torch::zeros_like(qyP);
    auto dydv = torch::zeros_like(qyP);
    auto dydq = torch::zeros_like(qyP);
    // print strides of dydk, dydv, dydq
    nvtxRangePop();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));

    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to the memory layout given by the inputs: if an input was
    // channels-last, leave its gradient in that layout, otherwise convert the layout
    // back to [batch, channel, ho, wo].

    // convert back to original dtype
    dydk = dydk.to(kx_type);
    dydv = dydv.to(vx_type);
    dydq = dydq.to(qy_type);

    // permute back to original layout
    if (!kx_is_channels_last) {
        dydk = dydk.to(kx_type, at::MemoryFormat::Contiguous);
    } else {
        dydk = dydk.to(kx_type);
    }
    if (!vx_is_channels_last) {
        dydv = dydv.to(vx_type, at::MemoryFormat::Contiguous);
    } else {
        dydv = dydv.to(vx_type);
    }
    if (!qy_is_channels_last) {
        dydq = dydq.to(qy_type, at::MemoryFormat::Contiguous);
    } else {
        dydq = dydq.to(qy_type);
    }

    // printf("dydk strides: [");
    // for(auto& stride : dydk.strides()) {
    //     printf("%ld,", stride);
    // }
    // printf("]\n");
    // cudaDeviceSynchronize();
    // delete permute_output_timer;
    // nvtxRangePop();

    return std::make_tuple(dydk, dydv, dydq);
}
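As a sanity check on the launch configuration in the wrapper, the arithmetic below is an illustrative calculation (not part of the commit), using the problem size from the timing comment (batch 1, 256 channels, 721 x 1440 grid) as an assumed example: with THREADS = 64 and WARP_SIZE = 32 each block carries two warps, i.e. two output points, and the dynamic shared-memory request is five num_channels-sized float arrays per warp.

#include <cstdio>

// Standalone arithmetic mirroring the launch setup above; the constants and the
// example problem size are assumptions taken from the code and its comments.
constexpr int kWarpSize = 32;
constexpr int kThreads = 64;

constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int num_channels = 256, nlat_out = 721, nlon_out = 1440, batch_size = 1;
    const int warps_per_block = kThreads / kWarpSize;                        // 2 output points per block
    const int grid_x = div_up(nlat_out * nlon_out, warps_per_block);         // 519120 blocks
    const size_t shared = sizeof(float) * num_channels * 5 * warps_per_block; // 10240 bytes

    std::printf("grid=(%d,%d) block=(%d,%d) dynamic smem=%zu bytes\n",
                grid_x, batch_size, kWarpSize, warps_per_block, shared);
    return 0;
}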