Unverified Commit e62804ff authored by Yao Matrix's avatar Yao Matrix Committed by GitHub
Browse files

enable bria integration test on xpu, passed (#12214)


Signed-off-by: default avatarYAO Matrix <matrix.yao@intel.com>
parent bb1d9a8b
...@@ -28,10 +28,10 @@ from diffusers import ( ...@@ -28,10 +28,10 @@ from diffusers import (
) )
from diffusers.pipelines.bria import BriaPipeline from diffusers.pipelines.bria import BriaPipeline
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism, enable_full_determinism,
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
require_accelerator, require_torch_accelerator,
require_torch_gpu,
slow, slow,
torch_device, torch_device,
) )
...@@ -149,7 +149,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -149,7 +149,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert (output_height, output_width) == (expected_height, expected_width) assert (output_height, output_width) == (expected_height, expected_width)
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator @require_torch_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2): def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components() components = self.get_dummy_components()
for name, module in components.items(): for name, module in components.items():
...@@ -237,7 +237,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -237,7 +237,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_accelerator
class BriaPipelineSlowTests(unittest.TestCase): class BriaPipelineSlowTests(unittest.TestCase):
pipeline_class = BriaPipeline pipeline_class = BriaPipeline
repo_id = "briaai/BRIA-3.2" repo_id = "briaai/BRIA-3.2"
...@@ -245,12 +245,12 @@ class BriaPipelineSlowTests(unittest.TestCase): ...@@ -245,12 +245,12 @@ class BriaPipelineSlowTests(unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
gc.collect() gc.collect()
torch.cuda.empty_cache() backend_empty_cache(torch_device)
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
gc.collect() gc.collect()
torch.cuda.empty_cache() backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0): def get_inputs(self, device, seed=0):
generator = torch.Generator(device="cpu").manual_seed(seed) generator = torch.Generator(device="cpu").manual_seed(seed)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment