Theoretically this might have lower numerical error since the scaling is in fp32 instead of fp16 (not sure, I haven't thought too carefully about it). However, in practice, the numerical errors seem about the same.
This speeds up the fwd by 1.5x.
From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each Q and dO have 2 buffers)