renzhc / diffusers_dcu · Commit e62804ff

Authored Aug 21, 2025 by Yao Matrix; committed via GitHub Aug 22, 2025

enable bria integration test on xpu, passed (#12214)

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

Parent: bb1d9a8b
Showing 1 changed file with 6 additions and 6 deletions

tests/pipelines/bria/test_pipeline_bria.py (+6, -6)
@@ -28,10 +28,10 @@ from diffusers import (
 )
 from diffusers.pipelines.bria import BriaPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_accelerator,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
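Note: the hunk above swaps the CUDA-only imports (require_torch_gpu, plus the plain require_accelerator) for backend_empty_cache and require_torch_accelerator, so the same tests can gate on either CUDA or XPU. Below is a minimal sketch of what a device-agnostic cache-clear helper could look like; the real backend_empty_cache ships in diffusers.utils.testing_utils, and this dispatch logic is only an assumption for illustration:

    import torch

    def empty_cache_for(device: str) -> None:  # hypothetical helper, not the diffusers implementation
        # Clear the allocator cache of whichever accelerator backend is in use.
        if device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif device == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()
        # On CPU (or unknown backends) there is no allocator cache to clear.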
@@ -149,7 +149,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert (output_height, output_width) == (expected_height, expected_width)

     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
-    @require_accelerator
+    @require_torch_accelerator
     def test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
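Note: only the gating decorator changes in this hunk; @require_accelerator becomes @require_torch_accelerator so the float16 save/load test runs on any torch accelerator the suite detects (CUDA or XPU). A hedged sketch of an equivalent gate, assuming the real decorator simply skips when torch_device resolves to "cpu":

    import unittest

    from diffusers.utils.testing_utils import torch_device

    def require_any_accelerator(test_case):  # hypothetical name, for illustration only
        # Skip the test unless torch_device points at an accelerator backend.
        return unittest.skipUnless(torch_device != "cpu", "test requires a torch accelerator")(test_case)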
@@ -237,7 +237,7 @@ class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class BriaPipelineSlowTests(unittest.TestCase):
     pipeline_class = BriaPipeline
     repo_id = "briaai/BRIA-3.2"
@@ -245,12 +245,12 @@ class BriaPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, seed=0):
         generator = torch.Generator(device="cpu").manual_seed(seed)
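Note: setUp and tearDown now call backend_empty_cache(torch_device) instead of torch.cuda.empty_cache(), so memory is released on whichever accelerator the test selected. get_inputs keeps seeding the generator on the CPU, which (an assumed rationale, not stated in the diff) keeps the noise reproducible regardless of the execution device:

    import torch

    # A CPU-seeded generator produces the same latents whether the pipeline runs on CUDA or XPU.
    generator = torch.Generator(device="cpu").manual_seed(0)
    # Hypothetical usage: pass it to the pipeline call, e.g. pipe(prompt, generator=generator)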