Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
7c226264
Unverified
Commit
7c226264
authored
Oct 13, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 13, 2022
Browse files
Align PT and Flax API - allow loading checkpoint from PyTorch configs (#827)
* up * finish * add more tests * up * up * finish
parent
78db11db
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
165 additions
and
31 deletions
+165
-31
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+32
-23
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+31
-4
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+1
-1
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+99
-1
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
7c226264
...
...
@@ -111,6 +111,9 @@ class FlaxDiffusionPipeline(ConfigMixin):
from
diffusers
import
pipelines
for
name
,
module
in
kwargs
.
items
():
if
module
is
None
:
register_dict
=
{
name
:
(
None
,
None
)}
else
:
# retrieve library
library
=
module
.
__module__
.
split
(
"."
)[
0
]
...
...
@@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
pipeline_class
=
cls
else
:
diffusers_module
=
importlib
.
import_module
(
cls
.
__module__
.
split
(
"."
)[
0
])
class_name
=
(
config_dict
[
"_class_name"
]
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
else
"Flax"
+
config_dict
[
"_class_name"
]
)
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_class_name"
])
# some modules can be passed directly to the init
...
...
@@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
loaded_sub_model
=
None
sub_model_should_be_defined
=
True
# if the model is in a pipeline module, then we load it from the pipeline
if
name
in
passed_class_obj
:
...
...
@@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
f
"
{
passed_class_obj
[
name
]
}
is of type:
{
type
(
passed_class_obj
[
name
])
}
, but should be"
f
"
{
expected_class_obj
}
"
)
elif
passed_class_obj
[
name
]
is
None
:
logger
.
warn
(
f
"You have passed `None` for
{
name
}
to disable its functionality in
{
pipeline_class
}
. Note"
f
" that this might lead to problems when using
{
pipeline_class
}
and is not recommended."
)
sub_model_should_be_defined
=
False
else
:
logger
.
warn
(
f
"You have passed a non-standard module
{
passed_class_obj
[
name
]
}
. We cannot verify whether it"
...
...
@@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model
=
passed_class_obj
[
name
]
elif
is_pipeline_module
:
pipeline_module
=
getattr
(
pipelines
,
library_name
)
if
from_pt
:
class_obj
=
import_flax_or_no_model
(
pipeline_module
,
class_name
)
else
:
class_obj
=
getattr
(
pipeline_module
,
class_name
)
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
importable_classes
.
keys
()}
else
:
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
if
from_pt
:
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
else
:
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
7c226264
...
...
@@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from
...models
import
FlaxAutoencoderKL
,
FlaxUNet2DConditionModel
from
...pipeline_flax_utils
import
FlaxDiffusionPipeline
from
...schedulers
import
FlaxDDIMScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
from
...utils
import
logging
from
.
import
FlaxStableDiffusionPipelineOutput
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
class
FlaxStableDiffusionPipeline
(
FlaxDiffusionPipeline
):
r
"""
Pipeline for text-to-image generation using Stable Diffusion.
...
...
@@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
super
().
__init__
()
self
.
dtype
=
dtype
if
safety_checker
is
None
:
logger
.
warn
(
f
"You have disabled the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self
.
register_modules
(
vae
=
vae
,
text_encoder
=
text_encoder
,
...
...
@@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
height
,
width
,
guidance_scale
,
latents
,
debug
)
if
self
.
safety_checker
is
not
None
:
safety_params
=
params
[
"safety_checker"
]
images
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
images
=
np
.
asarray
(
images
).
reshape
(
-
1
,
height
,
width
,
3
)
images
,
has_nsfw_concept
=
self
.
_run_safety_checker
(
images
,
safety_params
,
jit
)
images_uint8_casted
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
num_devices
,
batch_size
=
images
.
shape
[:
2
]
images_uint8_casted
=
np
.
asarray
(
images_uint8_casted
).
reshape
(
num_devices
*
batch_size
,
height
,
width
,
3
)
images_uint8_casted
,
has_nsfw_concept
=
self
.
_run_safety_checker
(
images_uint8_casted
,
safety_params
,
jit
)
images
=
np
.
asarray
(
images
)
# block images
if
any
(
has_nsfw_concept
):
for
i
,
is_nsfw
in
enumerate
(
has_nsfw_concept
):
images
[
i
]
=
np
.
asarray
(
images_uint8_casted
[
i
])
images
=
images
.
reshape
(
num_devices
,
batch_size
,
height
,
width
,
3
)
else
:
has_nsfw_concept
=
False
if
not
return_dict
:
return
(
images
,
has_nsfw_concept
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
7c226264
...
...
@@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
7c226264
...
...
@@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
7c226264
...
...
@@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
tests/test_pipelines_flax.py
View file @
7c226264
...
...
@@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow
if
is_flax_available
():
import
jax
import
jax.numpy
as
jnp
from
diffusers
import
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
...
...
@@ -34,7 +35,7 @@ if is_flax_available():
class
FlaxPipelineTests
(
unittest
.
TestCase
):
def
test_dummy_all_tpus
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
)
prompt
=
(
...
...
@@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase):
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
64
,
64
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
4.151474
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
49947.875
))
<
1e-2
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
assert
len
(
images_pil
)
==
8
def
test_stable_diffusion_v1_4
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"flax"
,
safety_checker
=
None
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
for
i
,
image
in
enumerate
(
images_pil
):
image
.
save
(
f
"/home/patrick/images/flax-test-
{
i
}
_fp32.png"
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.05652401
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2383808.2
))
<
1e-2
def
test_stable_diffusion_v1_4_bfloat_16
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
,
safety_checker
=
None
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
def
test_stable_diffusion_v1_4_bfloat_16_with_safety
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
pipeline
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
jit
=
True
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment