Commit 31920dda authored by Tri Dao
Browse files

Fix typo with lse_max == -INFINITY

parent 8a326bbc
......@@ -1118,7 +1118,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
MaxOp<float> max_op;
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
- lse_max == lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
+ lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
float lse_sum = expf(lse_accum(0) - lse_max);
#pragma unroll
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment