Unverified Commit 5e3f8fff authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix some audio tests (#3841)

* Fix some audio tests

* make style

* fix

* make style
parent 5df2acf7
......@@ -36,7 +36,7 @@ from diffusers import (
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils import is_xformers_available, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
......@@ -361,9 +361,15 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
@slow
# @require_torch_gpu
class AudioLDMPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
......
......@@ -640,7 +640,9 @@ class PipelineTesterMixin:
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass()
def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True, expected_max_diff=1e-4):
def _test_xformers_attention_forwardGenerator_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-4
):
if not self.test_xformers_attention:
return
......@@ -660,7 +662,8 @@ class PipelineTesterMixin:
max_diff = np.abs(output_with_offload - output_without_offload).max()
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
def test_progress_bar(self):
components = self.get_dummy_components()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment