"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "759ea587082aa0e77449952d8f3523f28ddc61f3"
Commit ec3050b1 authored by Max Rietmann

Optimize bwd kernel: incremental qdotk_max and alpha/integral/etc

Leverage the same qdotk_max "trick" for the backward kernel: fold the
separate max-finding pass into the accumulation pass by rescaling the
running sums whenever the maximum grows. This removes one loop over the
sparse row and reduces runtime by about 20%.
parent 584e1bd6
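For reference, the rescaling at the heart of this change can be illustrated outside the kernel. The sketch below is not part of this commit: it is standalone C++ with made-up scores, the quadrature weight omitted, and names borrowed from the kernel. It accumulates a softmax-normalized sum in a single pass by keeping a running maximum and multiplying the partial sums by expf(old_max - new_max) whenever the maximum grows, then checks the result against the plain two-pass version.

// Sketch (not the actual kernel): one-pass softmax-weighted accumulation with a
// running maximum, compared against the straightforward two-pass version.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical per-neighbor scores (qdotk) and values (gdotv); placeholders only.
    std::vector<float> qdotk = {0.3f, 2.1f, -1.0f, 4.5f, 3.2f};
    std::vector<float> gdotv = {1.0f, -0.5f, 2.0f, 0.25f, 1.5f};
    const int n = (int)qdotk.size();

    // Two-pass reference: find the maximum first, then accumulate.
    float m = -INFINITY;
    for (float q : qdotk) m = fmaxf(m, q);
    float ref_sum = 0.0f, ref_int = 0.0f;
    for (int i = 0; i < n; i++) {
        float a = expf(qdotk[i] - m);
        ref_sum += a;
        ref_int += a * gdotv[i];
    }

    // One-pass version: keep a running maximum and rescale the previous partial
    // sums by expf(old_max - new_max) whenever the maximum grows (the kernel's "trick").
    float qdotk_max = -INFINITY, alpha_sum = 0.0f, integral = 0.0f;
    for (int i = 0; i < n; i++) {
        float qdotk_max_tmp = fmaxf(qdotk_max, qdotk[i]);
        float alpha_inz = expf(qdotk[i] - qdotk_max_tmp);
        float max_correction = expf(qdotk_max - qdotk_max_tmp);
        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv[i];
        qdotk_max = qdotk_max_tmp;
    }

    // Both normalized results should agree (up to float rounding).
    printf("two-pass: %f  one-pass: %f\n", ref_int / ref_sum, integral / alpha_sum);
    return 0;
}

Because every partial sum is always expressed relative to the current maximum, no exponential is evaluated at a large positive argument, so the one-pass variant keeps the numerical stability of the two-pass max-subtraction while touching the data only once.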
@@ -159,21 +159,7 @@ __launch_bounds__(BDIM_X)
     const int64_t rend = psi_row_offset[ho+1];
     const int rlen = rend - rbeg;
-    // First pass: find qdotk_max
-    for (int off = 0; off < rlen; off++) {
-        const int64_t col = psi_col_idx[rbeg + off];
-        const int hi = col / nlon_in;
-        const int wi = col - (hi * nlon_in);
-        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
-        float qdotk = 0.0f;
-        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
-            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
-        }
-        qdotk = __warp_sum_cub(qdotk);
-        qdotk_max = max(qdotk_max, qdotk);
-    }
-    // Second pass: accumulate alpha_sum, integral, and shared stats
+    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
     for (int off = 0; off < rlen; off++) {
         const int64_t col = psi_col_idx[rbeg + off];
         const int hi = col / nlon_in;
@@ -186,15 +172,18 @@ __launch_bounds__(BDIM_X)
         }
         qdotk = __warp_sum_cub(qdotk);
         gdotv = __warp_sum_cub(gdotv);
-        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
-        alpha_sum += alpha_inz;
-        integral += alpha_inz * gdotv;
+        float qdotk_max_tmp = max(qdotk_max, qdotk);
+        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
+        float max_correction = expf(qdotk_max - qdotk_max_tmp);
+        alpha_sum = alpha_sum * max_correction + alpha_inz;
+        integral = integral * max_correction + alpha_inz * gdotv;
         for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
             float kxval = kx[batchId][chan][hi][wip];
-            sh_alpha_k[chan] += alpha_inz * kxval;
-            sh_alpha_vw[chan] += alpha_inz * gdotv;
-            sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
+            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
+            sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
+            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
         }
+        qdotk_max = qdotk_max_tmp;
     }
     integral /= alpha_sum;
@@ -330,7 +319,13 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
     // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
+    // s2_attention_bwd_kernel execution time: 51.231743 ms
+    // s2_attention_bwd_kernel execution time: 52.971519 ms
+    // s2_attention_bwd_kernel execution time: 50.724865 ms
+    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
+    // s2_attention_bwd_kernel execution time: 11.679744 ms
+    printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));