Unverified Commit 2236292a authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Disable fused attention in encoder tests for determinism (#2601)



disable fused attention in encoder tests for determinism
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 4df43dbe
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Encoder training on multi-GPU with tensor parallelism""" """Encoder training on multi-GPU with tensor parallelism"""
import argparse import argparse
import os
import unittest import unittest
from functools import partial from functools import partial
...@@ -489,6 +490,9 @@ class TestEncoder(unittest.TestCase): ...@@ -489,6 +490,9 @@ class TestEncoder(unittest.TestCase):
def setUp(self): def setUp(self):
"""Run 5 epochs for testing""" """Run 5 epochs for testing"""
# TODO(jberchtold): Remove once fused attention from cuDNN supports determinism on Blackwell
if "NVTE_FUSED_ATTN" not in os.environ:
os.environ["NVTE_FUSED_ATTN"] = "0"
self.args = encoder_parser(["--epochs", "5"]) self.args = encoder_parser(["--epochs", "5"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment