Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7c226264
Unverified
Commit
7c226264
authored
Oct 13, 2022
by
Patrick von Platen
Committed by
GitHub
Oct 13, 2022
Browse files
Align PT and Flax API - allow loading checkpoint from PyTorch configs (#827)
* up * finish * add more tests * up * up * finish
parent
78db11db
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
165 additions
and
31 deletions
+165
-31
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+32
-23
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
...elines/stable_diffusion/pipeline_flax_stable_diffusion.py
+31
-4
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
...s/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+1
-1
tests/test_pipelines_flax.py
tests/test_pipelines_flax.py
+99
-1
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
7c226264
...
@@ -111,6 +111,9 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -111,6 +111,9 @@ class FlaxDiffusionPipeline(ConfigMixin):
from
diffusers
import
pipelines
from
diffusers
import
pipelines
for
name
,
module
in
kwargs
.
items
():
for
name
,
module
in
kwargs
.
items
():
if
module
is
None
:
register_dict
=
{
name
:
(
None
,
None
)}
else
:
# retrieve library
# retrieve library
library
=
module
.
__module__
.
split
(
"."
)[
0
]
library
=
module
.
__module__
.
split
(
"."
)[
0
]
...
@@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
pipeline_class
=
cls
pipeline_class
=
cls
else
:
else
:
diffusers_module
=
importlib
.
import_module
(
cls
.
__module__
.
split
(
"."
)[
0
])
diffusers_module
=
importlib
.
import_module
(
cls
.
__module__
.
split
(
"."
)[
0
])
class_name
=
(
config_dict
[
"_class_name"
]
if
config_dict
[
"_class_name"
].
startswith
(
"Flax"
)
else
"Flax"
+
config_dict
[
"_class_name"
]
)
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_class_name"
])
pipeline_class
=
getattr
(
diffusers_module
,
config_dict
[
"_class_name"
])
# some modules can be passed directly to the init
# some modules can be passed directly to the init
...
@@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
loaded_sub_model
=
None
loaded_sub_model
=
None
sub_model_should_be_defined
=
True
# if the model is in a pipeline module, then we load it from the pipeline
# if the model is in a pipeline module, then we load it from the pipeline
if
name
in
passed_class_obj
:
if
name
in
passed_class_obj
:
...
@@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
f
"
{
passed_class_obj
[
name
]
}
is of type:
{
type
(
passed_class_obj
[
name
])
}
, but should be"
f
"
{
passed_class_obj
[
name
]
}
is of type:
{
type
(
passed_class_obj
[
name
])
}
, but should be"
f
"
{
expected_class_obj
}
"
f
"
{
expected_class_obj
}
"
)
)
elif
passed_class_obj
[
name
]
is
None
:
logger
.
warn
(
f
"You have passed `None` for
{
name
}
to disable its functionality in
{
pipeline_class
}
. Note"
f
" that this might lead to problems when using
{
pipeline_class
}
and is not recommended."
)
sub_model_should_be_defined
=
False
else
:
else
:
logger
.
warn
(
logger
.
warn
(
f
"You have passed a non-standard module
{
passed_class_obj
[
name
]
}
. We cannot verify whether it"
f
"You have passed a non-standard module
{
passed_class_obj
[
name
]
}
. We cannot verify whether it"
...
@@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model
=
passed_class_obj
[
name
]
loaded_sub_model
=
passed_class_obj
[
name
]
elif
is_pipeline_module
:
elif
is_pipeline_module
:
pipeline_module
=
getattr
(
pipelines
,
library_name
)
pipeline_module
=
getattr
(
pipelines
,
library_name
)
if
from_pt
:
class_obj
=
import_flax_or_no_model
(
pipeline_module
,
class_name
)
class_obj
=
import_flax_or_no_model
(
pipeline_module
,
class_name
)
else
:
class_obj
=
getattr
(
pipeline_module
,
class_name
)
importable_classes
=
ALL_IMPORTABLE_CLASSES
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
class_obj
for
c
in
importable_classes
.
keys
()}
else
:
else
:
# else we just import it from the library.
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
if
from_pt
:
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
else
:
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
issubclass
(
class_obj
,
class_candidate
):
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
View file @
7c226264
...
@@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
...
@@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from
...models
import
FlaxAutoencoderKL
,
FlaxUNet2DConditionModel
from
...models
import
FlaxAutoencoderKL
,
FlaxUNet2DConditionModel
from
...pipeline_flax_utils
import
FlaxDiffusionPipeline
from
...pipeline_flax_utils
import
FlaxDiffusionPipeline
from
...schedulers
import
FlaxDDIMScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
from
...schedulers
import
FlaxDDIMScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
from
...utils
import
logging
from
.
import
FlaxStableDiffusionPipelineOutput
from
.
import
FlaxStableDiffusionPipelineOutput
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
class
FlaxStableDiffusionPipeline
(
FlaxDiffusionPipeline
):
class
FlaxStableDiffusionPipeline
(
FlaxDiffusionPipeline
):
r
"""
r
"""
Pipeline for text-to-image generation using Stable Diffusion.
Pipeline for text-to-image generation using Stable Diffusion.
...
@@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
...
@@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
super
().
__init__
()
super
().
__init__
()
self
.
dtype
=
dtype
self
.
dtype
=
dtype
if
safety_checker
is
None
:
logger
.
warn
(
f
"You have disabled the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self
.
register_modules
(
self
.
register_modules
(
vae
=
vae
,
vae
=
vae
,
text_encoder
=
text_encoder
,
text_encoder
=
text_encoder
,
...
@@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
...
@@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
height
,
width
,
guidance_scale
,
latents
,
debug
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
height
,
width
,
guidance_scale
,
latents
,
debug
)
)
if
self
.
safety_checker
is
not
None
:
safety_params
=
params
[
"safety_checker"
]
safety_params
=
params
[
"safety_checker"
]
images
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
images_uint8_casted
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
images
=
np
.
asarray
(
images
).
reshape
(
-
1
,
height
,
width
,
3
)
num_devices
,
batch_size
=
images
.
shape
[:
2
]
images
,
has_nsfw_concept
=
self
.
_run_safety_checker
(
images
,
safety_params
,
jit
)
images_uint8_casted
=
np
.
asarray
(
images_uint8_casted
).
reshape
(
num_devices
*
batch_size
,
height
,
width
,
3
)
images_uint8_casted
,
has_nsfw_concept
=
self
.
_run_safety_checker
(
images_uint8_casted
,
safety_params
,
jit
)
images
=
np
.
asarray
(
images
)
# block images
if
any
(
has_nsfw_concept
):
for
i
,
is_nsfw
in
enumerate
(
has_nsfw_concept
):
images
[
i
]
=
np
.
asarray
(
images_uint8_casted
[
i
])
images
=
images
.
reshape
(
num_devices
,
batch_size
,
height
,
width
,
3
)
else
:
has_nsfw_concept
=
False
if
not
return_dict
:
if
not
return_dict
:
return
(
images
,
has_nsfw_concept
)
return
(
images
,
has_nsfw_concept
)
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
View file @
7c226264
...
@@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
...
@@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
if
safety_checker
is
None
:
logger
.
warn
(
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
7c226264
...
@@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
if
safety_checker
is
None
:
logger
.
warn
(
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
7c226264
...
@@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
...
@@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if
safety_checker
is
None
:
if
safety_checker
is
None
:
logger
.
warn
(
logger
.
warn
(
f
"You have disabed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
f
"You have disab
l
ed the safety checker for
{
self
.
__class__
}
by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
...
...
tests/test_pipelines_flax.py
View file @
7c226264
...
@@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow
...
@@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow
if
is_flax_available
():
if
is_flax_available
():
import
jax
import
jax
import
jax.numpy
as
jnp
from
diffusers
import
FlaxStableDiffusionPipeline
from
diffusers
import
FlaxStableDiffusionPipeline
from
flax.jax_utils
import
replicate
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
...
@@ -34,7 +35,7 @@ if is_flax_available():
...
@@ -34,7 +35,7 @@ if is_flax_available():
class
FlaxPipelineTests
(
unittest
.
TestCase
):
class
FlaxPipelineTests
(
unittest
.
TestCase
):
def
test_dummy_all_tpus
(
self
):
def
test_dummy_all_tpus
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-pipe"
"hf-internal-testing/tiny-stable-diffusion-pipe"
,
safety_checker
=
None
)
)
prompt
=
(
prompt
=
(
...
@@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase):
...
@@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase):
prompt_ids
=
shard
(
prompt_ids
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
64
,
64
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
4.151474
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
49947.875
))
<
1e-2
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
assert
len
(
images_pil
)
==
8
assert
len
(
images_pil
)
==
8
def
test_stable_diffusion_v1_4
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"flax"
,
safety_checker
=
None
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
images_pil
=
pipeline
.
numpy_to_pil
(
np
.
asarray
(
images
.
reshape
((
num_samples
,)
+
images
.
shape
[
-
3
:])))
for
i
,
image
in
enumerate
(
images_pil
):
image
.
save
(
f
"/home/patrick/images/flax-test-
{
i
}
_fp32.png"
)
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.05652401
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2383808.2
))
<
1e-2
def
test_stable_diffusion_v1_4_bfloat_16
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
,
safety_checker
=
None
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
p_sample
=
pmap
(
pipeline
.
__call__
,
static_broadcasted_argnums
=
(
3
,))
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
p_sample
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
def
test_stable_diffusion_v1_4_bfloat_16_with_safety
(
self
):
pipeline
,
params
=
FlaxStableDiffusionPipeline
.
from_pretrained
(
"CompVis/stable-diffusion-v1-4"
,
revision
=
"bf16"
,
dtype
=
jnp
.
bfloat16
)
prompt
=
(
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed
=
jax
.
random
.
PRNGKey
(
0
)
num_inference_steps
=
50
num_samples
=
jax
.
device_count
()
prompt
=
num_samples
*
[
prompt
]
prompt_ids
=
pipeline
.
prepare_inputs
(
prompt
)
# shard inputs and rng
params
=
replicate
(
params
)
prng_seed
=
jax
.
random
.
split
(
prng_seed
,
8
)
prompt_ids
=
shard
(
prompt_ids
)
images
=
pipeline
(
prompt_ids
,
params
,
prng_seed
,
num_inference_steps
,
jit
=
True
).
images
assert
images
.
shape
==
(
8
,
1
,
512
,
512
,
3
)
assert
np
.
abs
((
np
.
abs
(
images
[
0
,
0
,
:
2
,
:
2
,
-
2
:],
dtype
=
np
.
float32
).
sum
()
-
0.06652832
))
<
1e-3
assert
np
.
abs
((
np
.
abs
(
images
,
dtype
=
np
.
float32
).
sum
()
-
2384849.8
))
<
1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment