Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f1d4289b
Unverified
Commit
f1d4289b
authored
Oct 13, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 13, 2022
Browse files
[Flax] Add test (#824)
parent
323a9e1f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
8 deletions
+74
-8
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+0
-1
src/diffusers/utils/testing_utils.py
src/diffusers/utils/testing_utils.py
+12
-7
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+62
-0
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
f1d4289b
...
...
@@ -52,7 +52,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
):
super
().
__init__
()
scheduler
=
scheduler
.
set_format
(
"np"
)
self
.
dtype
=
dtype
self
.
register_modules
(
...
...
src/diffusers/utils/testing_utils.py
View file @
f1d4289b
...
...
@@ -7,21 +7,26 @@ from distutils.util import strtobool
from
pathlib
import
Path
from
typing
import
Union
import
torch
import
PIL.Image
import
PIL.ImageOps
import
requests
from
packaging
import
version
from
.import_utils
import
is_flax_available
from
.import_utils
import
is_flax_available
,
is_torch_available
global_rng
=
random
.
Random
()
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
is_torch_higher_equal_than_1_12
=
version
.
parse
(
version
.
parse
(
torch
.
__version__
).
base_version
)
>=
version
.
parse
(
"1.12"
)
if
is_torch_higher_equal_than_1_12
:
if
is_torch_available
():
import
torch
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
is_torch_higher_equal_than_1_12
=
version
.
parse
(
version
.
parse
(
torch
.
__version__
).
base_version
)
>=
version
.
parse
(
"1.12"
)
if
is_torch_higher_equal_than_1_12
:
torch_device
=
"mps"
if
torch
.
backends
.
mps
.
is_available
()
else
torch_device
...
...
tests/test_pipelines_flax.py
0 → 100644
View file @
f1d4289b
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
diffusers.utils
import
is_flax_available
from
diffusers.utils.testing_utils
import
require_flax
,
slow
if
is_flax_available
():
import
jax
from
diffusers
import
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
from
jax
import
pmap
@
require_flax
@
slow
class
FlaxPipelineTests
(
unittest
.
TestCase
):
def
test_dummy_all_tpus
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
4
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
assert
len
(
images_pil
)
==
8
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment