Unverified Commit 54bc882d authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

`mps` test fixes (#2470)

* Skip variant tests (UNet1d, UNetRL) on mps.

mish op not yet supported.

* Exclude a couple of panorama tests on mps

They are too slow for fast CI.

* Exclude mps panorama from more tests.

* mps: exclude all fast panorama tests as they keep failing.
parent 589faa8c
...@@ -66,6 +66,10 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -66,6 +66,10 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
super().test_from_save_pretrained() super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
super().test_model_from_pretrained() super().test_model_from_pretrained()
...@@ -186,6 +190,10 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -186,6 +190,10 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
super().test_from_save_pretrained() super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
super().test_model_from_pretrained() super().test_model_from_pretrained()
......
...@@ -30,7 +30,7 @@ from diffusers import ( ...@@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import slow, torch_device from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...test_pipelines_common import PipelineTesterMixin from ...test_pipelines_common import PipelineTesterMixin
...@@ -38,6 +38,7 @@ from ...test_pipelines_common import PipelineTesterMixin ...@@ -38,6 +38,7 @@ from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
@skip_mps
class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPanoramaPipeline pipeline_class = StableDiffusionPanoramaPipeline
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment