Commit 3d06f4da authored by Max Rietmann

Removed unnecessary code in fwd and bwd kernels.

Also: made the fwd kernel use a channels-last memory layout while keeping the standard [batch, channel, ho, wo] shape
parent 6512d042
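A note on the layout change described in the commit message (an illustrative sketch, not part of the commit; the standalone program, the tensor name x, and the example sizes are assumptions): permuting to [batch, ho, wo, channel] and calling .contiguous() makes the channel dimension fastest-varying in memory, and a second permute restores the logical [batch, channel, ho, wo] shape without another copy, so kernel accessors keep the standard indexing while reads across channels stay coalesced.

#include <torch/torch.h>
#include <iostream>

int main() {
  // start from the standard [batch, channel, ho, wo] layout
  torch::Tensor x = torch::randn({2, 8, 4, 4});
  // channels-last in memory; the second permute restores the original logical shape
  torch::Tensor xP = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
  std::cout << "sizes:   " << xP.sizes()   << std::endl;  // [2, 8, 4, 4], same as x
  std::cout << "strides: " << xP.strides() << std::endl;  // [128, 1, 32, 8]: channel stride is 1
  return 0;
}

This mirrors the kxP/vxP/qyP handling added to s2_attention_fwd_cuda further down in the diff.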
@@ -83,18 +83,6 @@ private:
std::chrono::high_resolution_clock::time_point start_;
};
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
@@ -105,7 +93,7 @@ static __device__ float __warp_sum(float val) {
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
-__device__ float __warp_sum_cub(float val) {
+static __device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
@@ -303,9 +291,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
nvtxRangePop();
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
-auto dydkP = torch::zeros_like(qyP);
-auto dydvP = torch::zeros_like(qyP);
-auto dydqP = torch::zeros_like(qyP);
+auto dydk = torch::zeros_like(qyP);
+auto dydv = torch::zeros_like(qyP);
+auto dydq = torch::zeros_like(qyP);
// print stride of dydkP, dydvP, dydqP
nvtxRangePop();
@@ -329,9 +317,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
@@ -351,13 +339,9 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// Permute outputs back to memory layout given by input. if input had channels
// first, leave it in that layout, otherwise permute layout back to [batch,
// channel, ho, wo]
-at::Tensor dydk, dydv, dydq;
-if(!k_channel_first) dydk = dydkP.contiguous();
-else dydk = dydkP;
-if(!v_channel_first) dydv = dydvP.contiguous();
-else dydv = dydvP;
-if(!q_channel_first) dydq = dydqP.contiguous();
-else dydq = dydqP;
+if(!k_channel_first) dydk = dydk.contiguous();
+if(!v_channel_first) dydv = dydv.contiguous();
+if(!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:["); // printf("dydk strides:[");
// for(auto& stride : dydk.strides()) { // for(auto& stride : dydk.strides()) {
...
@@ -65,137 +65,6 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
exit(EXIT_FAILURE); \
}}
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
__global__ void s2_attention_kernel(int num_channels, int nlon_in, int nlat_out,
int nlon_out,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
// shared memory
extern __shared__ float sharedMem[];
float *sh_alpha_sum = (float *)&sharedMem;
float* sh_qdotk_max = (float*)&sharedMem[1];
float* sh_qy_ho_wo = (float *)&sharedMem[2];
if (threadIdx.x == 0) {
sh_qdotk_max[0] = std::numeric_limits<float>::lowest();
sh_alpha_sum[0] = 0.0;
}
__syncthreads();
int ho = blockIdx.x;
int wo = blockIdx.y;
int batch_b = blockIdx.z;
// load qy channels into shared memory
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
sh_qy_ho_wo[channel_idx] = qy[batch_b][channel_idx][ho][wo];
y[batch_b][channel_idx][ho][wo] = 0.0;
}
__syncthreads();
int psi_offset = psi_row_offset[ho];
int psi_nnz_ho = psi_row_offset[ho + 1] - psi_offset;
float qdotk_max = std::numeric_limits<float>::lowest();
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz >= psi_nnz_ho) break;
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
qdotk_max = std::max(qdotk_max, qdotk);
}
// collect thread-local qdotk max
atomicMax(&sh_qdotk_max[0], qdotk_max);
__syncthreads();
// "broadcast" qdotk_max back into all thread-local registers
qdotk_max = sh_qdotk_max[0];
float alpha_sum = 0.0f;
for(int psi_block=0; psi_block<(psi_nnz_ho/blockDim.x)+1; psi_block++) {
int idz = psi_block*blockDim.x + threadIdx.x;
float alpha_inz = 0.0;
// skip if index >= length of psi_idx because last loop iteration will have extra threads
if(idz < psi_nnz_ho) {
int nz_col_idx = psi_col_idx[psi_offset+idz];
// compute input indices from psi datastructure
int hi = nz_col_idx / nlon_in;
// account for output shift and ensure positive index due to circular condition
// int wi = (nz_col_idx % nlon_in - wo) % nlon_in;
int wi = nz_col_idx % nlon_in;
int wip = (wi + wo) % nlon_in;
// softmax numerator with minus qdotk_max to avoid numerical overflow.
// Because qdotk_max is in both numerator and denominator (due to
// alpha_sum), it doesn't affect the solution other than removing overflow
// correlation Q&K (dot-product Q.K)
float qdotk = 0.0;
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
qdotk += sh_qy_ho_wo[channel_idx]*kx[batch_b][channel_idx][hi][wip];
}
alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
// thread-local sum alpha
alpha_sum += alpha_inz;
// multiply alpha, vx, and quadrature weights
for(int channel_idx = 0; channel_idx<num_channels; channel_idx++) {
atomicAdd(&y[batch_b][channel_idx][ho][wo], alpha_inz * vx[batch_b][channel_idx][hi][wip]);
}
}
}
// collect all alpha_sum across threads
atomicAdd(&sh_alpha_sum[0], alpha_sum);
__syncthreads();
// rebroadcast sum to all threads
alpha_sum = sh_alpha_sum[0];
// divide output by alpha_sum
for(int channel_block_i = 0; channel_block_i<(num_channels/blockDim.x)+1; channel_block_i++) {
int channel_idx = channel_block_i*blockDim.x + threadIdx.x;
if(channel_idx >= num_channels) break;
y[batch_b][channel_idx][ho][wo] /= alpha_sum;
}
}
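Aside (not part of the diff): the qdotk_max subtraction in the kernel above is the usual max-shift for a numerically stable softmax. Subtracting the same constant from every logit cancels between numerator and denominator, so the result is unchanged, but the arguments of expf stay non-positive and cannot overflow in single precision. A minimal host-side sketch with made-up values:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // expf(90) overflows a float, so the unshifted softmax would produce inf/inf
  std::vector<float> logits = {80.0f, 85.0f, 90.0f};
  float m = *std::max_element(logits.begin(), logits.end());
  float denom = 0.0f;
  for (float x : logits) denom += std::exp(x - m);        // shifted arguments are <= 0
  for (float x : logits) std::printf("%f ", std::exp(x - m) / denom);
  std::printf("\n");                                      // same values as the exact softmax
  return 0;
}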
static __device__ float __warp_sum(float val) {
#pragma unroll
for(int i = WARP_SIZE/2; i; i /= 2) {
@@ -204,11 +73,24 @@ static __device__ float __warp_sum(float val) {
return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
// use cub to reduce within a warp
__shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
// 1. Compute sum (initially only in lane 0)
float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
// 2. Broadcast sum to all threads
sum = __shfl_sync(0xFFFFFFFF, sum, 0);
return sum;
}
// one warp per (ho,wo)
template<int BDIM_X>
__global__
__launch_bounds__(BDIM_X)
-void s2_attention_kernel_mbT(int num_channels,
+void s2_attention_kernel(int num_channels,
int nlon_in,
int nlat_out,
int nlon_out,
@@ -263,10 +145,10 @@ __launch_bounds__(BDIM_X)
float qdotk = 0.0f;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-qdotk += qy[batchId][ho][ wo][chan]*
-         kx[batchId][hi][wip][chan];
+qdotk += qy[batchId][chan][ho][ wo]*
+         kx[batchId][chan][hi][wip];
}
-qdotk = __warp_sum(qdotk);
+qdotk = __warp_sum_cub(qdotk);
float qdotk_max_tmp;
float alpha;
@@ -279,7 +161,7 @@ __launch_bounds__(BDIM_X)
alpha_sum = alpha + alpha_sum*exp_save;
for(int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-shy[chan] = shy[chan]*exp_save + vx[batchId][hi][wip][chan]*alpha;
+shy[chan] = shy[chan]*exp_save + vx[batchId][chan][hi][wip]*alpha;
}
qdotk_max = qdotk_max_tmp;
}
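A side note on the update in this hunk (a minimal scalar sketch of the general streaming-softmax idea; the variable names below are made up and not taken from the kernel): whenever a larger running maximum appears, the previously accumulated denominator and weighted value sum are rescaled by exp(old_max - new_max) before the new term is added, which appears to be the role of exp_save in the alpha_sum and shy[chan] updates above.

#include <cmath>
#include <cstdio>

int main() {
  // one-pass softmax-weighted average without a precomputed maximum
  const float logit[4] = {1.0f, 3.0f, 2.0f, 5.0f};
  const float value[4] = {10.f, 20.f, 30.f, 40.f};
  float run_max = -INFINITY, denom = 0.0f, numer = 0.0f;
  for (int i = 0; i < 4; ++i) {
    float new_max = std::fmax(run_max, logit[i]);
    float rescale = std::exp(run_max - new_max);  // rescales everything accumulated so far
    float w = std::exp(logit[i] - new_max);       // weight of the new term
    denom = denom * rescale + w;
    numer = numer * rescale + w * value[i];
    run_max = new_max;
  }
  std::printf("softmax-weighted average: %f\n", numer / denom);
  return 0;
}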
@@ -317,12 +199,35 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
const int batch_size = kx.size(0);
auto k_channel_first = kx.strides()[1] == 1;
auto v_channel_first = vx.strides()[1] == 1;
auto q_channel_first = qy.strides()[1] == 1;
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
-torch::Tensor kxP = kx.permute({0,2,3,1}).contiguous();
-torch::Tensor vxP = vx.permute({0,2,3,1}).contiguous();
-torch::Tensor qyP = qy.permute({0, 2, 3, 1}).contiguous();
+//Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
+auto kxP = at::Tensor();
+if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
kxP = kx;
}
auto vxP = at::Tensor();
if (!v_channel_first) {
// printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
vxP = vx;
}
auto qyP = at::Tensor();
if (!q_channel_first) {
// printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
} else {
qyP = qy;
}
cudaDeviceSynchronize();
nvtxRangePop();
torch::Tensor y = torch::empty_like(qy);
@@ -338,7 +243,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start, stream));
-s2_attention_kernel_mbT<THREADS>
+s2_attention_kernel<THREADS>
<<<grid, block, shared_size, stream>>>(uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -355,6 +260,9 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
// match output layout to input layout
if (!q_channel_first) y = y.contiguous();
C10_CUDA_KERNEL_LAUNCH_CHECK();
return y;
...