renzhc / diffusers_dcu · Commits · f1d4289b

Unverified commit f1d4289b, authored Oct 13, 2022 by Patrick von Platen, committed by GitHub on Oct 13, 2022
[Flax] Add test (#824)
Parent: 323a9e1f
Showing 3 changed files with 74 additions and 8 deletions (+74 / -8):

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py (+0 / -1)
src/diffusers/utils/testing_utils.py (+12 / -7)
tests/test_pipelines_flax.py (+62 / -0)
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -52,7 +52,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         dtype: jnp.dtype = jnp.float32,
     ):
         super().__init__()
-        scheduler = scheduler.set_format("np")
         self.dtype = dtype
         self.register_modules(
src/diffusers/utils/testing_utils.py
@@ -7,21 +7,26 @@ from distutils.util import strtobool
 from pathlib import Path
 from typing import Union

-import torch
 import PIL.Image
 import PIL.ImageOps
 import requests
 from packaging import version

-from .import_utils import is_flax_available
+from .import_utils import is_flax_available, is_torch_available


 global_rng = random.Random()

-torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
+if is_torch_available():
+    import torch

-if is_torch_higher_equal_than_1_12:
-    torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
+
+    if is_torch_higher_equal_than_1_12:
+        torch_device = "mps" if torch.backends.mps.is_available() else torch_device
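The net effect of this hunk is that testing_utils no longer imports torch unconditionally; everything torch-specific now runs only when is_torch_available() reports that torch is installed, so the module can still be imported in a Flax-only environment. Below is a minimal sketch of that guard pattern; the find_spec-based helper is an assumption for illustration, not the actual implementation in diffusers.utils.import_utils.

# Illustrative optional-dependency guard; the importlib-based check is an
# assumption, not necessarily how diffusers implements is_torch_available().
import importlib.util


def is_torch_available() -> bool:
    # True only when the torch package can be resolved in this environment.
    return importlib.util.find_spec("torch") is not None


if is_torch_available():
    import torch

    # Device selection runs only when torch is actually installed.
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    torch_device = None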
tests/test_pipelines_flax.py (new file, 0 → 100644)
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow


if is_flax_available():
    import jax

    from diffusers import FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard
    from jax import pmap


@require_flax
@slow
class FlaxPipelineTests(unittest.TestCase):
    def test_dummy_all_tpus(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")

        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 4

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = pipeline.prepare_inputs(prompt)

        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, 8)
        prompt_ids = shard(prompt_ids)

        images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
        images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

        assert len(images_pil) == 8
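The new test fans sampling out across every available accelerator with jax.pmap: the pipeline parameters are replicated to each device, the prompt ids and RNG keys are sharded along the leading axis, and num_inference_steps is passed as a static argument (static_broadcasted_argnums=(3,)). The snippet below is a minimal, self-contained sketch of that pattern under a single-host assumption; toy_sample and its shapes are hypothetical stand-ins for pipeline.__call__, not part of the diffusers API.

# Minimal sketch of the pmap/replicate/shard pattern used by the test.
# toy_sample is a hypothetical stand-in for pipeline.__call__.
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard


def toy_sample(inputs, params, rng, num_steps):
    # Stand-in for a sampling call: mix the inputs with a parameter and noise.
    noise = jax.random.normal(rng, inputs.shape)
    return inputs * params["scale"] + noise * num_steps


num_devices = jax.device_count()

inputs = np.ones((num_devices, 4), dtype=np.float32)  # one batch entry per device
params = {"scale": jnp.asarray(2.0)}

# Argument 3 (num_steps) is a static Python value, mirroring how the test
# passes num_inference_steps to pipeline.__call__.
p_sample = jax.pmap(toy_sample, static_broadcasted_argnums=(3,))

params = replicate(params)                                  # copy params to every device
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)  # one RNG key per device
inputs = shard(inputs)                                      # split the batch across devices

out = p_sample(inputs, params, rng, 4)
assert out.shape[0] == num_devices

In the test itself the same roles are played by the params returned by from_pretrained, the prompt_ids produced by pipeline.prepare_inputs, and the prng_seed produced by jax.random.split.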