Commit f8cb598c authored by Tim Moon, committed by GitHub

[PyTorch] Only disable Flash Attention in Userbuffers test on SM 8.0 (#2401)



Only disable Flash Attention in Userbuffers test on A100
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent a75da0ca
@@ -120,6 +120,10 @@ def _run_layer_with_overlap(
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
if te.get_device_compute_capability() <= (8, 0):
        # We've observed numerical discrepancies in the Flash
        # Attention backward pass when running with Userbuffers on
        # A100s (SM 8.0). This does not show up on more recent GPUs.
os.environ["NVTE_FLASH_ATTN"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
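The gating logic above can be sketched in isolation. This is a minimal, hypothetical helper (not part of Transformer Engine) that mirrors the diff's behavior: given a `(major, minor)` compute capability, it sets `NVTE_FLASH_ATTN=0` for SM 8.0 and older, relying on Python's lexicographic tuple comparison.

```python
import os


def maybe_disable_flash_attn(compute_capability, env=None):
    """Hypothetical sketch of the test's gating logic.

    compute_capability: (major, minor) tuple, e.g. (8, 0) for A100.
    env: mutable mapping of environment variables (defaults to os.environ).
    """
    if env is None:
        env = os.environ
    # Tuples compare lexicographically, so (8, 0) <= (8, 0) is True
    # while (8, 6) and (9, 0) are not, matching the diff's condition.
    if compute_capability <= (8, 0):
        env["NVTE_FLASH_ATTN"] = "0"
    return env


# Usage sketch: A100 (SM 8.0) gets Flash Attention disabled,
# Hopper (SM 9.0) is left untouched.
a100_env = maybe_disable_flash_attn((8, 0), env={})
hopper_env = maybe_disable_flash_attn((9, 0), env={})
```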