Unverified commit 5da00de0, authored by Thorsten Kurth, committed by GitHub

Merge pull request #80 from rietmann-nv/mr/optimized-bwd-qdotk_max

Optimize bwd kernel: incremental qdotk_max with rescaling of alpha_sum, integral, and the shared accumulators
parents 76836abf 65058287
@@ -159,21 +159,7 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
     const int64_t rend = psi_row_offset[ho + 1];
     const int rlen = rend - rbeg;
 
-    // First pass: find qdotk_max
-    for (int off = 0; off < rlen; off++) {
-        const int64_t col = psi_col_idx[rbeg + off];
-        const int hi = col / nlon_in;
-        const int wi = col - (hi * nlon_in);
-        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
-        float qdotk = 0.0f;
-        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
-        }
-        qdotk = __warp_sum_cub(qdotk);
-        qdotk_max = max(qdotk_max, qdotk);
-    }
-
-    // Second pass: accumulate alpha_sum, integral, and shared stats
+    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
     for (int off = 0; off < rlen; off++) {
         const int64_t col = psi_col_idx[rbeg + off];
         const int hi = col / nlon_in;
@@ -186,15 +172,18 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
         }
         qdotk = __warp_sum_cub(qdotk);
         gdotv = __warp_sum_cub(gdotv);
-        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
-        alpha_sum += alpha_inz;
-        integral += alpha_inz * gdotv;
+        float qdotk_max_tmp = max(qdotk_max, qdotk);
+        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
+        float max_correction = expf(qdotk_max - qdotk_max_tmp);
+        alpha_sum = alpha_sum * max_correction + alpha_inz;
+        integral = integral * max_correction + alpha_inz * gdotv;
         for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
             float kxval = kx[batchId][chan][hi][wip];
-            sh_alpha_k[chan] += alpha_inz * kxval;
-            sh_alpha_vw[chan] += alpha_inz * gdotv;
-            sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
+            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
+            sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
+            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
         }
+        qdotk_max = qdotk_max_tmp;
     }
 
     integral /= alpha_sum;
@@ -323,8 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
 
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
+    // s2_attention_bwd_kernel execution time: 50.724865 ms
+    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
+    // s2_attention_bwd_kernel execution time: 11.679744 ms
+    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
 
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
...
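The change above removes the separate max-finding pass and applies the usual online-softmax rescaling instead: whenever the running qdotk_max grows, every running sum is multiplied by expf(old_max - new_max) so that all accumulated terms stay normalized to the current maximum. Below is a minimal host-side C++ sketch of that idea for checking the algebra (illustrative only; it is not repository code, the data and names are made up, and the kernel's warp reductions, quadrature indexing, and per-channel shared arrays are left out):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Made-up per-neighbor scores, quadrature weights, and gdotv values.
    std::vector<float> qdotk = {0.5f, 2.0f, -1.0f, 3.0f};
    std::vector<float> w     = {1.0f, 0.5f, 2.0f, 1.0f};
    std::vector<float> gdotv = {0.1f, -0.2f, 0.4f, 0.3f};

    // Two-pass reference: find the max first, then accumulate.
    float m_ref = -INFINITY;
    for (float s : qdotk) m_ref = std::max(m_ref, s);
    float sum_ref = 0.0f, integral_ref = 0.0f;
    for (size_t i = 0; i < qdotk.size(); ++i) {
        float a = std::exp(qdotk[i] - m_ref) * w[i];
        sum_ref += a;
        integral_ref += a * gdotv[i];
    }

    // Single pass: update the max incrementally and rescale the running
    // sums by exp(old_max - new_max) whenever the max changes.
    float m = -INFINITY, sum = 0.0f, integral = 0.0f;
    for (size_t i = 0; i < qdotk.size(); ++i) {
        float m_new = std::max(m, qdotk[i]);
        float a = std::exp(qdotk[i] - m_new) * w[i];
        float corr = std::exp(m - m_new);  // 0 on the first iteration (m == -inf)
        sum = sum * corr + a;
        integral = integral * corr + a * gdotv[i];
        m = m_new;
    }

    // Both variants should agree up to rounding.
    std::printf("two-pass:    integral=%f alpha_sum=%f\n", integral_ref / sum_ref, sum_ref);
    std::printf("single-pass: integral=%f alpha_sum=%f\n", integral / sum, sum);
    return 0;
}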