[PyTorch] Fix FP8 logic related to FA2/FA3 (#1141)
* fix FP8 logic when FA3 is not installed Signed-off-by:Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor tweak to make logic more explicit Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * limit FA3 warning to Hopper and NVTE_FLASH_ATTN=1 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prefer fused attn for FP8 Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by:
Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Showing
Please register or sign in to comment