Commit 6512d042 authored by Max Rietmann

Removed all stale backwards kernel code

Also match the gradient output to the input in terms of memory layout
parent 4096e64b
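Not part of the commit itself, but for context on the layout matching described above: the idea is to check whether channels are the fastest-varying dimension of an input (stride 1 on dim 1) and only materialize a contiguous [batch, channel, ho, wo] gradient when they are not. A minimal sketch using the PyTorch C++ API (function and variable names here are illustrative, not from this repository):

```cuda
#include <torch/torch.h>

// Return `grad` in the same memory layout as `input`: if the input already has
// channels as the fastest-varying dimension, keep the channels-last-style
// gradient as-is; otherwise materialize a standard contiguous
// [batch, channel, ho, wo] tensor.
at::Tensor match_gradient_layout(const at::Tensor& input, const at::Tensor& grad) {
    const bool channels_fastest = input.strides()[1] == 1;
    return channels_fastest ? grad : grad.contiguous();
}
```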
...@@ -289,7 +289,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0").to(memory_format=torch.channels_last)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0")
time_backward_start = torch.cuda.Event(enable_timing=True)
time_backward_end = torch.cuda.Event(enable_timing=True)
...
...@@ -49,30 +49,3 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
torch::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out);
...@@ -116,634 +116,14 @@ __device__ float __warp_sum_cub(float val) {
return sum;
}
__global__ void
s2_attention_bwd_dv_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem; // 1
float* sh_qdotk_max = (float*)&sharedMem[1]; // 1
float* sh_qy_ho_wo = (float*)&sharedMem[2]; // num_channels
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
// form alpha & sum alpha
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
}
// collect thread-local alpha_sum
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
// alpha * dy * omega / alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, dy, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydv[batch_b][channel_idx][hi][wip], (alpha_inz/alpha_sum) * dy[batch_b][channel_idx][ho][wo]);
}
}
}
at::Tensor s2_attention_bwd_dv_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydv = torch::zeros_like(vx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dv_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydv;
}
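The removed kernels above all form the softmax weights with the usual max-subtraction trick (`expf(qdotk - qdotk_max) * quad_weights[hi]`). As a standalone illustration of that numerical-stability pattern (a host-side sketch, not code from this file):

```cuda
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable, quadrature-weighted softmax: subtracting the maximum
// logit before exponentiating keeps expf() from overflowing, and the result
// is unchanged because the factor exp(-max) cancels in the normalization.
std::vector<float> stable_softmax_weights(const std::vector<float>& qdotk,
                                          const std::vector<float>& quad_w) {
    float qdotk_max = qdotk[0];
    for (float x : qdotk) qdotk_max = std::max(qdotk_max, x);

    std::vector<float> alpha(qdotk.size());
    float alpha_sum = 0.0f;
    for (size_t i = 0; i < qdotk.size(); ++i) {
        alpha[i] = std::exp(qdotk[i] - qdotk_max) * quad_w[i];  // softmax numerator
        alpha_sum += alpha[i];
    }
    for (float& a : alpha) a /= alpha_sum;  // normalize
    return alpha;
}
```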
__global__ void
s2_attention_bwd_dk_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_integral = (float *)&sharedMem[1 + num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 2 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = max(qdotk_max, qdotk);
}
// compute max over all threads
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha & integral
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
}
// block sum thread-local alpha_sum and integral
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finish integral computation
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// broadcast sum and integral back to thread-local registers
integral = sh_integral[0];
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz/alpha_sum) * (gdotv - integral));
}
}
__syncthreads();
}
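The kernels above call `atomicMax` on a float in shared memory (e.g. `atomicMax(&sh_qdotk_max[0], qdotk_max)`). CUDA's built-in `atomicMax` has no float overload, so this presumably relies on a helper defined elsewhere in the file; a common CAS-based sketch of such a helper looks like the following (illustrative, not this file's actual definition):

```cuda
// Compare-and-swap loop implementing atomicMax for float. Reinterpreting the
// bits as int is safe here because the loop only ever writes back `value`'s
// own bit pattern, and it retries until no other thread interferes.
__device__ float atomicMaxFloat(float* addr, float value) {
    int* addr_as_int = reinterpret_cast<int*>(addr);
    int old = *addr_as_int;
    while (__int_as_float(old) < value) {
        int assumed = old;
        old = atomicCAS(addr_as_int, assumed, __float_as_int(value));
        if (old == assumed) break;  // our value was written
    }
    return __int_as_float(old);
}
```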
__global__ void
s2_attention_bwd_dq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float* sh_alpha_sum = (float*)&sharedMem;
float *sh_qy_ho_wo = (float *)&sharedMem[1];
float *sh_alpha_k = (float *)&sharedMem[1 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[1 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[1 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[1 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[1 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
}
// sum thread-local alpha_sums across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// "broadcast" alpha sum back to thread-local registers
alpha_sum = sh_alpha_sum[0];
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
}
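For reference, the final accumulation in this removed dq kernel is the quotient rule applied to the normalized attention output. With $\alpha_i = e^{q\cdot k_i - \max_j q\cdot k_j}\,\omega_{h_i}$ over the nonzero neighborhood of an output point and $g$ the incoming gradient $dy$ at that point (notation introduced here only to summarize the code above):

$$
\frac{\partial L}{\partial q_c}
= \frac{\Big(\sum_i \alpha_i\, k_{c,i}\,(g\cdot v_i)\Big)\Big(\sum_i \alpha_i\Big)
      \;-\;\Big(\sum_i \alpha_i\,(g\cdot v_i)\Big)\Big(\sum_i \alpha_i\, k_{c,i}\Big)}
       {\Big(\sum_i \alpha_i\Big)^{2}}
$$

which is exactly `(sh_alpha_kvw*sh_alpha_sum - sh_alpha_vw*sh_alpha_k) / (alpha_sum*alpha_sum)` in the last loop.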
__global__ void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>
dy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_integral = (float*)&sharedMem[1];
float *sh_qy_ho_wo = (float *)&sharedMem[2];
float *sh_alpha_k = (float *)&sharedMem[2 + num_channels];
float *sh_alpha_vw = (float *)&sharedMem[2 + 2*num_channels];
float *sh_alpha_kvw = (float *)&sharedMem[2 + 3*num_channels];
float *sh_dy_ho_wo = (float *)&sharedMem[2 + 4 * num_channels];
float *sh_qdotk_max = (float *)&sharedMem[2 + 5 * num_channels];
if (threadIdx.x == 0) {
sh_alpha_sum[0] = 0.0;
sh_integral[0] = 0.0;
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory and zero temporary variables
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
sh_dy_ho_wo[channel_idx] = dy[batch_b][channel_idx][ho][wo];
sh_alpha_k[channel_idx] = 0.0f;
sh_alpha_vw[channel_idx] = 0.0f;
sh_alpha_kvw[channel_idx] = 0.0f;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk, qdotk_max);
}
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0;
float integral = 0.0;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0f;
float gdotv = 0.0f;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx] * kx[batch_b][channel_idx][hi][wip];
}
// softmax numerator
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// sum alpha
alpha_sum += alpha_inz;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&sh_alpha_k[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip]);
atomicAdd(&sh_alpha_vw[channel_idx],
alpha_inz * gdotv);
atomicAdd(&sh_alpha_kvw[channel_idx],
alpha_inz * kx[batch_b][channel_idx][hi][wip] * gdotv);
}
integral += alpha_inz * gdotv;
}
// sum thread-local alpha_sums & integral across block
atomicAdd(&sh_alpha_sum[0], alpha_sum);
atomicAdd(&sh_integral[0], integral);
__syncthreads();
// finalize integral
if(threadIdx.x==0) sh_integral[0] /= sh_alpha_sum[0];
__syncthreads();
// "broadcast" alpha sum & integral back to thread-local registers
alpha_sum = sh_alpha_sum[0];
integral = sh_integral[0];
// dq
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if (channel_idx >= num_channels)
break;
dydq[batch_b][channel_idx][ho][wo] = (sh_alpha_kvw[channel_idx]*sh_alpha_sum[0] - sh_alpha_vw[channel_idx]*sh_alpha_k[channel_idx])/(alpha_sum*alpha_sum);
}
__syncthreads();
// dk & dv
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
float gdotv = 0.0;
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
gdotv += sh_dy_ho_wo[channel_idx] * vx[batch_b][channel_idx][hi][wip];
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// multiply alpha/sum_alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&dydk[batch_b][channel_idx][hi][wip],
sh_qy_ho_wo[channel_idx] * (alpha_inz / alpha_sum) *
(gdotv - integral));
atomicAdd(&dydv[batch_b][channel_idx][hi][wip],
(alpha_inz / alpha_sum) * sh_dy_ho_wo[channel_idx]);
}
}
__syncthreads();
}
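Similarly, the dk/dv scatter at the end of this removed fused kernel adds, per neighborhood entry $i$ and channel $c$ of one output point, the contributions below (again my notation, summarizing the `atomicAdd` lines above, with `integral` the normalized sum in the last term):

$$
\Delta\!\left(\frac{\partial L}{\partial v_{c,i}}\right) = \frac{\alpha_i}{\sum_j \alpha_j}\, g_c,
\qquad
\Delta\!\left(\frac{\partial L}{\partial k_{c,i}}\right) = q_c\,\frac{\alpha_i}{\sum_j \alpha_j}\Big(g\cdot v_i - \mathrm{integral}\Big),
\qquad
\mathrm{integral} = \frac{\sum_j \alpha_j\,(g\cdot v_j)}{\sum_j \alpha_j}
$$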
// New kernel: s2_attention_bwd_dkvq_kernel_mbT
// This kernel assumes kx, vx, qy, dy, dydk, dydv, dydq are all [batch, ho, wo, channel] (transposed)
// This kernel computes the backward pass for the S2 attention mechanism, using
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel_mbT(
void s2_attention_bwd_dkvq_kernel(
int num_channels,
int nlon_in,
int nlat_out,
...@@ -859,116 +239,8 @@ __launch_bounds__(BDIM_X)
}
}
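The body of the retained kernel is collapsed in this hunk, but the comment above describes one warp per output point with warp-parallelism over channels, channels fastest so that consecutive lanes make coalesced reads. A generic sketch of that pattern (illustrative only, not the kernel's actual code):

```cuda
// Warp-level dot product over the channel dimension: each of the 32 lanes
// accumulates a strided subset of channels, then partial sums are combined
// with shuffle reductions; lane 0 ends up holding the full dot product.
__device__ float warp_dot_over_channels(const float* __restrict__ q,
                                        const float* __restrict__ k,
                                        int num_channels) {
    float partial = 0.0f;
    for (int c = threadIdx.x % 32; c < num_channels; c += 32) {
        partial += q[c] * k[c];   // channels are contiguous -> coalesced loads
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        partial += __shfl_down_sync(0xffffffffu, partial, offset);
    }
    return partial;
}
```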
at::Tensor s2_attention_bwd_dk_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydk = torch::zeros_like(kx);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (2*uo_num_channels+3)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dk_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydk;
}
at::Tensor s2_attention_bwd_dq_cuda(at::Tensor kx,
at::Tensor vx,
at::Tensor qy,
at::Tensor dy,
at::Tensor quad_weights,
at::Tensor psi_col_idx,
at::Tensor psi_row_off,
int nlon_in, int nlat_out, int nlon_out) {
CHECK_CUDA_TENSOR(kx);
CHECK_CUDA_TENSOR(vx);
CHECK_CUDA_TENSOR(qy);
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
CHECK_CUDA_TENSOR(dy);
auto stream = at::cuda::getCurrentCUDAStream().stream();
torch::Tensor dydq = torch::zeros_like(qy);
size_t uo_num_channels = kx.size(1);
size_t sharedMemSize = (5*uo_num_channels+2)*sizeof(float);
const int batch_size = kx.size(0);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
s2_attention_bwd_dq_kernel <<<gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return dydq;
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx,
at::Tensor qy,
...@@ -988,83 +260,6 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
auto stream = at::cuda::getCurrentCUDAStream().stream();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
// enum for which kernel version
enum KERNEL_VERSION {
OLD_VERSION = 0,
HOWO_WARP_VERSION = 2,
};
auto version = HOWO_WARP_VERSION;
// auto version = OLD_VERSION;
if (version == OLD_VERSION) {
// printf("old version\n");
torch::Tensor dydk = torch::zeros_like(qy);
torch::Tensor dydv = torch::zeros_like(qy);
torch::Tensor dydq = torch::zeros_like(qy);
size_t sharedMemSize = (6*uo_num_channels+3)*sizeof(float);
// cuda grid y,z size limitations
assert(nlon_out < 65535);
assert(batch_size < 65535);
// block-parallel over output points and batches
dim3 gridDim(nlat_out,nlon_out,batch_size);
// threads compute "blocks" of neighborhood and also "blocks" of channels
dim3 blockDim(256, 1, 1);
// Define CUDA event variables for timing
cudaEvent_t start_event, stop_event;
cudaEventCreate(&start_event);
cudaEventCreate(&stop_event);
// Record the start event
cudaEventRecord(start_event, stream);
s2_attention_bwd_dkvq_kernel<<<
gridDim, blockDim, sharedMemSize, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>()
);
// Record the stop event
cudaEventRecord(stop_event, stream);
cudaEventSynchronize(stop_event);
// Calculate elapsed time
float kernel_time_ms;
cudaEventElapsedTime(&kernel_time_ms, start_event, stop_event);
// Output the result
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// Old bwd kernel execution time: 803.477 ms
// std::cout << "Old bwd kernel execution time: " << kernel_time_ms << " ms" << std::endl;
// Cleanup events
cudaEventDestroy(start_event);
cudaEventDestroy(stop_event);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(dydk, dydv, dydq);
} else if (version == HOWO_WARP_VERSION) {
// ScopeTimer timer("Full s2_attention_bwd_dkvq_kernel_mbT");
// Time this function via C++
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
...@@ -1127,7 +322,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
s2_attention_bwd_dkvq_kernel_mbT<THREADS><<<
s2_attention_bwd_dkvq_kernel<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
...@@ -1153,18 +348,26 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to [batch, channel, ho, wo]
// nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output permutation");
// auto* permute_output_timer = new ScopeTimer("permute outputs");
// auto dydk = dydkP.permute({0,3,1,2}).contiguous().permute({0,3,1,2});
// auto dydv = dydvP.permute({0,3,1,2}).contiguous();
// auto dydq = dydqP.permute({0, 3, 1, 2}).contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
//   printf("%ld,", stride);
// }
// printf("]\n");
// Permute outputs back to the memory layout given by the input: if the input had
// channels first, leave the gradient in that layout; otherwise permute back to
// [batch, channel, ho, wo].
at::Tensor dydk, dydv, dydq;
if(!k_channel_first) dydk = dydkP.contiguous();
else dydk = dydkP;
if(!v_channel_first) dydv = dydvP.contiguous();
else dydv = dydvP;
if(!q_channel_first) dydq = dydqP.contiguous();
else dydq = dydqP;
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydkP, dydvP, dydqP);
return std::make_tuple(dydk, dydv, dydq);
} else {
throw std::runtime_error("Invalid kernel version specified");
}
}
...@@ -33,10 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
m.def("backward_dk", &s2_attention_bwd_dk_cuda, "(Local) Attention gradient on S2 (gradient for k)");
m.def("backward_dv", &s2_attention_bwd_dv_cuda, "(Local) Attention gradient on S2 (gradient for v)");
m.def("backward_dq", &s2_attention_bwd_dq_cuda,
"(Local) Attention gradient on S2 (gradient for q)");
m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)"); m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
} }