Unverified commit 26370b11 authored by Kshitij Lakhani, committed by GitHub

[PyT] Bump the min version expected to support FP8 current scaling determinism on Blackwell (#2316)

* Bump the min version expected to support FP8 current scaling determinism on Blackwell
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Disable fused attention for cuDNN < 9.14 for FP8 current scaling; disable fused attention for cuDNN < 9.18 for deterministic FP8 current scaling
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0e80c847
@@ -477,9 +477,21 @@ def get_attention_backend(
     if device_compute_capability < (10, 0):
         logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
         use_fused_attention = False
-    elif cudnn_version < (9, 14, 0):
-        logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
-        use_fused_attention = False
+    # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling
+    # determinism for Blackwell
+    else:
+        if cudnn_version < (9, 14, 0):
+            logger.debug(
+                "Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0"
+            )
+            use_fused_attention = False
+        else:
+            if deterministic and cudnn_version < (9, 18, 0):
+                logger.debug(
+                    "Disabling FusedAttention for FP8 current scaling requiring determinism"
+                    " with cuDNN < 9.18.0"
+                )
+                use_fused_attention = False

     if device_compute_capability == (12, 0):
         if use_flash_attention:
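For reference, the gate this hunk modifies reduces to a pure predicate over the device compute capability, the cuDNN version, and the determinism flag. Below is a minimal standalone sketch of that logic; the function name and the tuple-shaped arguments are illustrative assumptions, not part of the TransformerEngine API:

```python
# Sketch of the FP8 current-scaling fused-attention gate after this commit.
# The helper name and tuple-based version comparisons are hypothetical;
# TransformerEngine performs these checks inside get_attention_backend.
def fused_attention_ok_for_fp8_current_scaling(
    device_compute_capability: tuple[int, int],
    cudnn_version: tuple[int, int, int],
    deterministic: bool,
) -> bool:
    # FP8 current scaling needs Blackwell (sm100) or newer.
    if device_compute_capability < (10, 0):
        return False
    # Non-deterministic FP8 current scaling needs cuDNN >= 9.14.0.
    if cudnn_version < (9, 14, 0):
        return False
    # Deterministic FP8 current scaling needs cuDNN >= 9.18.0.
    if deterministic and cudnn_version < (9, 18, 0):
        return False
    return True


# Example: on Blackwell with cuDNN 9.16, determinism forces a fallback,
# while cuDNN 9.18 keeps FusedAttention eligible.
assert not fused_attention_ok_for_fp8_current_scaling((10, 0), (9, 16, 0), True)
assert fused_attention_ok_for_fp8_current_scaling((10, 0), (9, 18, 0), True)
```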