Commit cd8b7507 (Unverified) in renzhc/diffusers_dcu
Authored Apr 18, 2023 by YiYi Xu; committed via GitHub on Apr 18, 2023

speed up attend-and-excite fast tests (#3079)

Parent: 3b641eab

Showing 1 changed file with 9 additions and 4 deletions (+9, -4):
tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py  (+9, -4)

@@ -44,7 +44,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unitt
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
-            layers_per_block=2,
+            layers_per_block=1,
             sample_size=32,
             in_channels=4,
             out_channels=4,

@@ -111,7 +111,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unitt
             "prompt": "a cat and a frog",
             "token_indices": [2, 5],
             "generator": generator,
-            "num_inference_steps": 2,
+            "num_inference_steps": 1,
             "guidance_scale": 6.0,
             "output_type": "numpy",
             "max_iter_to_alter": 2,

@@ -132,13 +132,18 @@ class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unitt
         image_slice = image[0, -3:, -3:, -1]
         self.assertEqual(image.shape, (1, 64, 64, 3))
-        expected_slice = np.array([0.5743, 0.6081, 0.4975, 0.5021, 0.5441, 0.4699, 0.4988, 0.4841, 0.4851])
+        expected_slice = np.array(
+            [0.63905364, 0.62897307, 0.48599017, 0.5133624, 0.5550048, 0.45769516, 0.50326973, 0.5023139, 0.45384496]
+        )
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)

     def test_inference_batch_consistent(self):
         # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
-        self._test_inference_batch_consistent(batch_sizes=[2, 4])
+        self._test_inference_batch_consistent(batch_sizes=[1, 2])
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(batch_size=2)


 @require_torch_gpu
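For context, the assertions in the last hunk follow the corner-slice regression pattern visible in the diff: take a 3x3 patch of one channel of the generated image and require every value to stay within 1e-3 of recorded reference values. Below is a minimal standalone sketch of that check, written against plain NumPy; the image array and the "golden" reference values here are synthetic placeholders, not output from the actual Attend-and-Excite pipeline.

import numpy as np

# Stand-in for pipeline output; real output has shape (batch, height, width, channels).
image = np.random.RandomState(0).rand(1, 64, 64, 3).astype(np.float32)

# Bottom-right 3x3 patch of the last channel of the first image, as in the test.
image_slice = image[0, -3:, -3:, -1]

# In the real test these nine values are recorded from a known-good run;
# here they are copied from the synthetic image so the check passes.
expected_slice = image_slice.flatten().copy()

# The check passes when every pixel differs from its reference by at most 1e-3.
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff <= 1e-3, f"max difference {max_diff} exceeds tolerance 1e-3"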