[PyTorch] Fix fused attention backward's FP8 dtypes (#1566)
* fix dtypes in fused attn bwd for FP8
* add comments for dtypes
* remove redundant qkv_dtype in fwd
* remove Nones in bwd returns

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>