Get P from Tensor Memory, reduce P within shared memory, perform masking, and store back if necessary
Initially, since dual gemm is used, we have two P pieces in Tensor Memory, one occupying rows 0 ~ 63 while the other occupying rows 64 ~ 127. We'd like to have them reduced into one single P piece, stored in registers with layout:
// We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load
Tile<Int<128>,Layout<Shape<_128,_2,_2>,Stride<_1,_256,_128>>,_16>{}// We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
boolhave_valid_indices=__any_sync(0xffffffff,li!=0);// Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
if(!have_valid_indices){
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
// NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if all indices are invalid.
// NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs.
// NOTE: We only do this for K to save some checking overhead, since after doing this for K, cases where topk indices are all invalid are faster than the other cases