Commit 191ba149 authored by Max Rietmann
Browse files

Re-introduce inline softmax from main

parent 3dd35b45
......@@ -159,21 +159,7 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
const int64_t rend = psi_row_offset[ho + 1];
const int rlen = rend - rbeg;
// First pass: find qdotk_max
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
const int wi = col - (hi * nlon_in);
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
}
// Second pass: accumulate alpha_sum, integral, and shared stats
// 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
for (int off = 0; off < rlen; off++) {
const int64_t col = psi_col_idx[rbeg + off];
const int hi = col / nlon_in;
......@@ -186,15 +172,18 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
float qdotk_max_tmp = max(qdotk_max, qdotk);
float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
float max_correction = expf(qdotk_max - qdotk_max_tmp);
alpha_sum = alpha_sum * max_correction + alpha_inz;
integral = integral * max_correction + alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][chan][hi][wip];
sh_alpha_k[chan] += alpha_inz * kxval;
sh_alpha_vw[chan] += alpha_inz * gdotv;
sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
}
qdotk_max = qdotk_max_tmp;
}
integral /= alpha_sum;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment