Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
187de443
Commit
187de443
authored
Nov 09, 2022
by
Patrick von Platen
Browse files
Fix device on save/load tests
parent
7d0c2729
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
6 deletions
+22
-6
tests/test_pipelines.py
tests/test_pipelines.py
+22
-6
No files found.
tests/test_pipelines.py
View file @
187de443
...
@@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase):
...
@@ -102,8 +102,12 @@ class DownloadTests(unittest.TestCase):
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
pipe_2
=
pipe_2
.
to
(
torch_device
)
pipe_2
=
pipe_2
.
to
(
torch_device
)
generator_2
=
generator
.
manual_seed
(
0
)
if
torch_device
==
"mps"
:
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
...
@@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase):
...
@@ -124,8 +128,14 @@ class DownloadTests(unittest.TestCase):
pipe
.
save_pretrained
(
tmpdirname
)
pipe
.
save_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
,
safety_checker
=
None
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
,
safety_checker
=
None
)
pipe_2
=
pipe_2
.
to
(
torch_device
)
pipe_2
=
pipe_2
.
to
(
torch_device
)
generator_2
=
generator
.
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
if
torch_device
==
"mps"
:
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
...
@@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase):
...
@@ -144,8 +154,14 @@ class DownloadTests(unittest.TestCase):
pipe
.
save_pretrained
(
tmpdirname
)
pipe
.
save_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
)
pipe_2
=
pipe_2
.
to
(
torch_device
)
pipe_2
=
pipe_2
.
to
(
torch_device
)
generator_2
=
generator
.
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
if
torch_device
==
"mps"
:
# device type MPS is not supported for torch.Generator() api.
generator
=
torch
.
manual_seed
(
0
)
else
:
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment