Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error since the scaling is in fp32 instead of fp16 (not sure, I haven't thought too carefully about it). However, in practice, the numerical errors seem about the same.
Showing
Please register or sign in to comment