Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1d512244
Unverified
Commit
1d512244
authored
Oct 13, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 13, 2022
Browse files
[Flax] Complete tests (#828)
parent
7c226264
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
9 deletions
+50
-9
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+50
-9
No files found.
tests/test_pipelines_flax.py
View file @
1d512244
...
@@ -24,7 +24,7 @@ from diffusers.utils.testing_utils import require_flax, slow
...
@@ -24,7 +24,7 @@ from diffusers.utils.testing_utils import require_flax, slow
if
is_flax_available
():
if
is_flax_available
():
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
diffusers
import
FlaxStableDiffusionPipeline
from
diffusers
import
FlaxDDIMScheduler
,
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
jax
import
pmap
from
jax
import
pmap
...
@@ -61,7 +61,7 @@ class FlaxPipelineTests(unittest.TestCase):
...
@@ -61,7 +61,7 @@ class FlaxPipelineTests(unittest.TestCase):
assert
images
.
shape
==
(
8
,
1
,
64
,
64
,
3
)
assert
images
.
shape
==
(
8
,
1
,
64
,
64
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
4.151474
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
4.151474
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
49947.875
))
<
1
e-
2
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
49947.875
))
<
5
e-
1
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
...
@@ -93,13 +93,9 @@ class FlaxPipelineTests(unittest.TestCase):
...
@@ -93,13 +93,9 @@ class FlaxPipelineTests(unittest.TestCase):
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
for
i
,
image
in
enumerate
(
images_pil
):
image
.
save
(
f
"/home/patrick/images/flax-test-
{
i
}
_fp32.png"
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.05652401
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.05652401
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2383808.2
))
<
1
e-
2
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2383808.2
))
<
5
e-
1
def
test_stable_diffusion_v1_4_bfloat_16
(
self
):
def
test_stable_diffusion_v1_4_bfloat_16
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
...
@@ -129,7 +125,7 @@ class FlaxPipelineTests(unittest.TestCase):
...
@@ -129,7 +125,7 @@ class FlaxPipelineTests(unittest.TestCase):
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1
e-
2
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
5
e-
1
def
test_stable_diffusion_v1_4_bfloat_16_with_safety
(
self
):
def
test_stable_diffusion_v1_4_bfloat_16_with_safety
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
...
@@ -157,4 +153,49 @@ class FlaxPipelineTests(unittest.TestCase):
...
@@ -157,4 +153,49 @@ class FlaxPipelineTests(unittest.TestCase):
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
5e-1
def
test_stable_diffusion_v1_4_bfloat_16_ddim
(
self
):
scheduler
=
FlaxDDIMScheduler
(
beta_start
=
0.00085
,
beta_end
=
0.012
,
beta_schedule
=
"scaled_linear"
,
set_alpha_to_one
=
False
,
steps_offset
=
1
,
)
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
,
scheduler
=
scheduler
,
safety_checker
=
None
,
)
scheduler_state
=
scheduler
.
create_state
()
params
[
"scheduler"
]
=
scheduler_state
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.045043945
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2347693.5
))
<
5e-1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment