Unverified Commit 4c52982a authored by Patrick von Platen, committed by GitHub

[Tests] Add MPS skip decorator (#2362)



* finish

* Apply suggestions from code review

* fix indent and import error in test_stable_diffusion_depth

---------
Co-authored-by: William Berman <WLBberman@gmail.com>
parent 2a49fac8
@@ -79,6 +79,7 @@ if is_torch_available():
         parse_flag_from_env,
         print_tensor_test,
         require_torch_gpu,
+        skip_mps,
         slow,
         torch_all_close,
         torch_device,
@@ -163,6 +163,11 @@ def require_torch_gpu(test_case):
     )


+def skip_mps(test_case):
+    """Decorator marking a test to skip if torch_device is 'mps'"""
+    return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
+
+
 def require_flax(test_case):
     """
     Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
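Note: the decorator added above is a thin wrapper around unittest.skipUnless, so it can be placed on an individual test method or on an entire TestCase class, exactly as the hunks below do. A minimal usage sketch follows; the test class and method names are hypothetical and not part of this commit, while skip_mps and torch_device come from diffusers.utils.testing_utils as shown above.

import unittest

from diffusers.utils.testing_utils import skip_mps, torch_device


@skip_mps  # skips every test in the class when torch_device == "mps"
class HypotheticalPipelineFastTests(unittest.TestCase):
    def test_forward_pass(self):
        # On an MPS machine this never runs; elsewhere it passes trivially.
        self.assertNotEqual(torch_device, "mps")


class HypotheticalMixedDeviceTests(unittest.TestCase):
    @skip_mps  # skips only this method on MPS
    def test_op_without_mps_support(self):
        self.assertNotEqual(torch_device, "mps")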
@@ -39,9 +39,8 @@ from diffusers import (
     StableDiffusionDepth2ImgPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
-from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils import floats_tensor, is_accelerate_available, load_image, load_numpy, nightly, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

 from ...test_pipelines_common import PipelineTesterMixin
@@ -49,7 +48,7 @@ from ...test_pipelines_common import PipelineTesterMixin
 torch.backends.cuda.matmul.allow_tf32 = False


-@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
+@skip_mps
 class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionDepth2ImgPipeline
     test_save_load_optional_components = False
@@ -154,7 +153,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
         }
         return inputs

-    @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_save_load_local(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -248,7 +246,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
         max_diff = np.abs(output_with_offload - output_without_offload).max()
         self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")

-    @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_dict_tuple_outputs_equivalent(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -265,7 +262,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
         max_diff = np.abs(output - output_tuple).max()
         self.assertLess(max_diff, 1e-4)

-    @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_progress_bar(self):
         super().test_progress_bar()
@@ -23,7 +23,7 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
 from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.utils import load_numpy, nightly, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

 from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -349,7 +349,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
     # because UnCLIP GPU undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
@@ -357,7 +357,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_inference_batch_single_identical(self):
         test_max_difference = torch_device == "cpu"
         relax_max_difference = True
@@ -374,15 +374,15 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             self._test_inference_batch_consistent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_dict_tuple_outputs_equivalent(self):
         return super().test_dict_tuple_outputs_equivalent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_optional_components(self):
         return super().test_save_load_optional_components()
@@ -37,7 +37,7 @@ from diffusers import (
 )
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import load_image, require_torch_gpu
+from diffusers.utils.testing_utils import load_image, require_torch_gpu, skip_mps

 from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -470,7 +470,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
     # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
     # because UnCLIP GPU undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
@@ -478,7 +478,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_inference_batch_single_identical(self):
         test_max_difference = torch_device == "cpu"
         relax_max_difference = True
@@ -495,15 +495,15 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
         else:
             self._test_inference_batch_consistent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_dict_tuple_outputs_equivalent(self):
         return super().test_dict_tuple_outputs_equivalent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_optional_components(self):
         return super().test_save_load_optional_components()