"...git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "0f966217e550f58dd92cac9ee802e5ad99ea91f6"
Unverified Commit 89177694 authored by Will Berman, committed by GitHub

attend and excite tests disable determinism on the class level (#3478)

parent 49b7ccfb
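The change moves the torch.use_deterministic_algorithms(False) toggle from module scope into class-level setUpClass/tearDownClass hooks, so strict determinism is only relaxed while the attend-and-excite test classes run and is restored afterwards. Below is a minimal, standalone sketch of that pattern; the class and test names are hypothetical and are not taken from the diffusers test file.

import unittest

import torch


class BackwardAtInferenceTests(unittest.TestCase):
    # Hypothetical test class illustrating the class-level determinism toggle.

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Relax determinism only for the tests in this class: the pipeline
        # under test runs a backward pass at inference time, and some backward
        # ops have no deterministic implementation.
        torch.use_deterministic_algorithms(False)

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        # Restore strict determinism so later test classes are unaffected.
        torch.use_deterministic_algorithms(True)

    def test_backward_runs(self):
        x = torch.randn(2, 2, requires_grad=True)
        x.pow(2).sum().backward()
        self.assertIsNotNone(x.grad)


if __name__ == "__main__":
    unittest.main()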
@@ -34,7 +34,6 @@ from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMix
 torch.backends.cuda.matmul.allow_tf32 = False
-torch.use_deterministic_algorithms(False)

 @skip_mps
@@ -47,6 +46,19 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
     image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

+    # Attend and excite requires being able to run a backward pass at
+    # inference time. There's no deterministic backward operator for pad
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        torch.use_deterministic_algorithms(False)
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        torch.use_deterministic_algorithms(True)
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -171,6 +183,19 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
 @require_torch_gpu
 @slow
 class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
+    # Attend and excite requires being able to run a backward pass at
+    # inference time. There's no deterministic backward operator for pad
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        torch.use_deterministic_algorithms(False)
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        torch.use_deterministic_algorithms(True)
+
     def tearDown(self):
         super().tearDown()
         gc.collect()
...
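For context on the in-code comment about pad: when torch.use_deterministic_algorithms(True) is active, CUDA ops whose backward has no deterministic implementation raise a RuntimeError instead of silently running nondeterministically, which is why these tests, which backpropagate through the model at inference time, must opt out. The sketch below illustrates that behaviour under the assumptions that a CUDA device is available and that replication padding is one such op; it is illustrative only and is not part of this commit.

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    torch.use_deterministic_algorithms(True)
    try:
        x = torch.randn(1, 1, 4, 4, device="cuda", requires_grad=True)
        # Replication padding's CUDA backward is expected to lack a
        # deterministic implementation, so this backward pass should raise.
        y = F.pad(x, (1, 1, 1, 1), mode="replicate")
        y.sum().backward()
        print("backward ran (this op may have gained a deterministic kernel)")
    except RuntimeError as err:
        print("rejected under deterministic mode:", err)
    finally:
        # Restore the default so the rest of the process is unaffected.
        torch.use_deterministic_algorithms(False)
else:
    print("CUDA not available; this determinism check is CUDA-specific.")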