"src/vscode:/vscode.git/clone" did not exist on "9bbd39b74185e0e17a2abfa3aa1ef9700b737600"
Unverified Commit f3e13104 authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

reset deterministic in tearDownClass (#11785)



* reset deterministic in tearDownClass
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>

* fix deterministic setting
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: default avatarjiqing-feng <jiqing.feng@intel.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 87f83d3d
...@@ -98,8 +98,15 @@ class Base4bitTests(unittest.TestCase): ...@@ -98,8 +98,15 @@ class Base4bitTests(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
@classmethod
def tearDownClass(cls):
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self): def get_dummy_inputs(self):
prompt_embeds = load_pt( prompt_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
......
...@@ -99,8 +99,15 @@ class Base8bitTests(unittest.TestCase): ...@@ -99,8 +99,15 @@ class Base8bitTests(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
@classmethod
def tearDownClass(cls):
if not cls.is_deterministic_enabled:
torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self): def get_dummy_inputs(self):
prompt_embeds = load_pt( prompt_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment