Commit 6613a8c7 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make CI happy

parent d9c449ea
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import gc import gc
import math import math
import os
import tracemalloc import tracemalloc
import unittest import unittest
...@@ -270,13 +269,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -270,13 +269,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing(self): def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing # enable deterministic behavior for gradient checkpointing
torch.use_deterministic_algorithms(True)
# from torch docs: "A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater,
# unless the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8 is set."
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
...@@ -313,10 +305,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -313,10 +305,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
for name in grad_checkpointed: for name in grad_checkpointed:
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5)) self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
# disable deterministic behavior for gradient checkpointing
del os.environ["CUBLAS_WORKSPACE_CONFIG"]
torch.use_deterministic_algorithms(False)
# TODO(Patrick) - Re-add this test after having cleaned up LDM # TODO(Patrick) - Re-add this test after having cleaned up LDM
# def test_output_pretrained_spatial_transformer(self): # def test_output_pretrained_spatial_transformer(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment