renzhc / diffusers_dcu, commit df91c447 (unverified)
Authored Mar 23, 2023 by YiYi Xu; committed via GitHub on Mar 23, 2023
Flax controlnet (#2727)

* add controlnet flax

Co-authored-by: yiyixuxu <yixu310@gmail.com>

Parent: aa0531fa
Showing 13 changed files with 1125 additions and 2 deletions (+1125 −2):
docs/source/en/api/models.mdx  (+6 −0)
docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx  (+6 −0)
src/diffusers/__init__.py  (+2 −0)
src/diffusers/models/__init__.py  (+1 −0)
src/diffusers/models/controlnet_flax.py  (+383 −0)
src/diffusers/models/unet_2d_condition_flax.py  (+16 −0)
src/diffusers/pipelines/__init__.py  (+1 −0)
src/diffusers/pipelines/pipeline_flax_utils.py  (+15 −2)
src/diffusers/pipelines/stable_diffusion/__init__.py  (+1 −0)
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py  (+537 −0)
src/diffusers/utils/dummy_flax_and_transformers_objects.py  (+15 −0)
src/diffusers/utils/dummy_flax_objects.py  (+15 −0)
tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py  (+127 −0)
docs/source/en/api/models.mdx

```diff
@@ -99,3 +99,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## FlaxAutoencoderKL
 [[autodoc]] FlaxAutoencoderKL
+
+## FlaxControlNetOutput
+[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
+
+## FlaxControlNetModel
+[[autodoc]] FlaxControlNetModel
```
docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

```diff
@@ -272,3 +272,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
 	- disable_vae_slicing
 	- enable_xformers_memory_efficient_attention
 	- disable_xformers_memory_efficient_attention
+
+## FlaxStableDiffusionControlNetPipeline
+[[autodoc]] FlaxStableDiffusionControlNetPipeline
+	- all
+	- __call__
```
src/diffusers/__init__.py

```diff
@@ -188,6 +188,7 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_flax_objects import *  # noqa F403
 else:
+    from .models.controlnet_flax import FlaxControlNetModel
     from .models.modeling_flax_utils import FlaxModelMixin
     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .models.vae_flax import FlaxAutoencoderKL
@@ -211,6 +212,7 @@ except OptionalDependencyNotAvailable:
     from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
 else:
     from .pipelines import (
+        FlaxStableDiffusionControlNetPipeline,
         FlaxStableDiffusionImg2ImgPipeline,
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
```
src/diffusers/models/__init__.py

```diff
@@ -30,5 +30,6 @@ if is_torch_available():
     from .vq_model import VQModel

 if is_flax_available():
+    from .controlnet_flax import FlaxControlNetModel
     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .vae_flax import FlaxAutoencoderKL
```
src/diffusers/models/controlnet_flax.py — new file (0 → 100644)

```python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
)


@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    down_block_res_samples: jnp.ndarray
    mid_block_res_sample: jnp.ndarray


class FlaxControlNetConditioningEmbedding(nn.Module):
    conditioning_embedding_channels: int
    block_out_channels: Tuple[int] = (16, 32, 96, 256)
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_in = nn.Conv(
            self.block_out_channels[0],
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        blocks = []
        for i in range(len(self.block_out_channels) - 1):
            channel_in = self.block_out_channels[i]
            channel_out = self.block_out_channels[i + 1]
            conv1 = nn.Conv(
                channel_in,
                kernel_size=(3, 3),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv1)
            conv2 = nn.Conv(
                channel_out,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv2)
        self.blocks = blocks

        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = nn.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = nn.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the
    library implements for all the models (such as downloading or saving).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
            "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D".
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross-attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for the down, up, and bottleneck blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
            The channel order of the conditioning image. Will be converted to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block in the `conditioning_embedding` layer.
    """

    sample_size: int = 32
    in_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    controlnet_conditioning_channel_order: str = "rgb"
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

    def setup(self):
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=self.conditioning_embedding_out_channels,
        )

        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        attention_head_dim = self.attention_head_dim
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)

        # down
        down_blocks = []
        controlnet_down_blocks = []

        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv(
            output_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )
        controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    attn_num_head_channels=attention_head_dim[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)

            for _ in range(self.layers_per_block):
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv(
                    output_channel,
                    kernel_size=(1, 1),
                    padding="VALID",
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                    dtype=self.dtype,
                )
                controlnet_down_blocks.append(controlnet_block)

        self.down_blocks = down_blocks
        self.controlnet_down_blocks = controlnet_down_blocks

        # mid
        mid_block_channel = block_out_channels[-1]
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
            attn_num_head_channels=attention_head_dim[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )

        self.controlnet_mid_block = nn.Conv(
            mid_block_channel,
            kernel_size=(1, 1),
            padding="VALID",
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale: float = 1.0,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxControlNetOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
            conditioning_scale (`float`): the scale factor for the ControlNet outputs
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.controlnet_flax.FlaxControlNetOutput`] instead of a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.controlnet_flax.FlaxControlNetOutput`] or `tuple`:
                [`~models.controlnet_flax.FlaxControlNetOutput`] if `return_dict` is True, otherwise a `tuple`. When
                returning a tuple, the first element is the down-block residual samples.
        """
        channel_order = self.controlnet_conditioning_channel_order
        if channel_order == "bgr":
            controlnet_cond = jnp.flip(controlnet_cond, axis=1)

        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return FlaxControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
```
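Because every `controlnet_block` convolution above is zero-initialized, a freshly initialized ControlNet contributes nothing to the UNet at step zero, which is the paper's "zero convolution" trick for safe fine-tuning. A minimal sketch of standalone use (the small `cross_attention_dim` and all input shapes are illustrative choices, not values from this commit):

```python
import jax
import jax.numpy as jnp

from diffusers import FlaxControlNetModel

# Toy config purely for illustration; real checkpoints use the defaults above.
controlnet = FlaxControlNetModel(sample_size=32, cross_attention_dim=768)
params = controlnet.init_weights(rng=jax.random.PRNGKey(0))

sample = jnp.zeros((1, 4, 32, 32))               # noisy latents (NCHW)
timesteps = jnp.ones((1,), dtype=jnp.int32)      # diffusion timestep
encoder_hidden_states = jnp.zeros((1, 77, 768))  # text embeddings
controlnet_cond = jnp.zeros((1, 3, 256, 256))    # conditioning image, 8x the latent size

out = controlnet.apply(
    {"params": params},
    sample,
    timesteps,
    encoder_hidden_states,
    controlnet_cond,
    conditioning_scale=1.0,
)
# One residual per down-block output plus the mid-block residual.
print(len(out.down_block_res_samples), out.mid_block_res_sample.shape)
```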
src/diffusers/models/unet_2d_condition_flax.py

```diff
@@ -249,6 +249,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         sample,
         timesteps,
         encoder_hidden_states,
+        down_block_additional_residuals=None,
+        mid_block_additional_residual=None,
         return_dict: bool = True,
         train: bool = False,
     ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
@@ -291,9 +293,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                 sample, res_samples = down_block(sample, t_emb, deterministic=not train)
             down_block_res_samples += res_samples

+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample += down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
         # 4. mid
         sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
         # 5. up
         for up_block in self.up_blocks:
             res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
```
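The hunk above is the whole UNet-side contract: the ControlNet residuals are summed elementwise onto the UNet's skip connections and mid-block activation. A toy sketch of that merge, with made-up shapes standing in for real feature maps:

```python
import jax.numpy as jnp

# Illustrative shapes only: three skip tensors plus a mid-block tensor (NHWC).
down_block_res_samples = (jnp.ones((1, 8, 8, 320)),) * 3
down_block_additional_residuals = (jnp.full((1, 8, 8, 320), 0.5),) * 3
mid = jnp.ones((1, 4, 4, 1280))
mid_additional = jnp.full((1, 4, 4, 1280), 0.5)

# The same merge the diff adds to FlaxUNet2DConditionModel.__call__:
new_samples = ()
for res, extra in zip(down_block_res_samples, down_block_additional_residuals):
    new_samples += (res + extra,)
down_block_res_samples = new_samples
mid = mid + mid_additional
```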
src/diffusers/pipelines/__init__.py

```diff
@@ -124,6 +124,7 @@ except OptionalDependencyNotAvailable:
     from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
 else:
     from .stable_diffusion import (
+        FlaxStableDiffusionControlNetPipeline,
         FlaxStableDiffusionImg2ImgPipeline,
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
```
src/diffusers/pipelines/pipeline_flax_utils.py

```diff
@@ -278,7 +278,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
         >>> from diffusers import FlaxDPMSolverMultistepScheduler

         >>> model_id = "runwayml/stable-diffusion-v1-5"
-        >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+        >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
         ...     model_id,
         ...     subfolder="scheduler",
         ... )
@@ -365,7 +365,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
         # some modules can be passed directly to the init
         # in this case they are already instantiated in `kwargs`
         # extract them here
-        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

         init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -470,6 +470,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

+        # 4. Potentially add passed objects if expected
+        missing_modules = set(expected_modules) - set(init_kwargs.keys())
+        passed_modules = list(passed_class_obj.keys())
+
+        if len(missing_modules) > 0 and missing_modules <= set(passed_modules):
+            for module in missing_modules:
+                init_kwargs[module] = passed_class_obj.get(module, None)
+        elif len(missing_modules) > 0:
+            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+            raise ValueError(
+                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+            )
+
         model = pipeline_class(**init_kwargs, dtype=dtype)

         return model, params
```
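The new step 4 is what lets callers pass `controlnet=controlnet` to `from_pretrained` for a repo that ships no ControlNet weights: any expected module not loaded from the repo may be supplied directly, and anything still missing (beyond `optional_kwargs`) raises. A toy sketch of the resolution logic, with strings standing in for real modules:

```python
# A minimal sketch of the new step 4, with placeholder data instead of real modules.
expected_modules = {"unet", "vae", "controlnet"}
optional_kwargs = {"dtype"}
init_kwargs = {"unet": "loaded-unet", "vae": "loaded-vae"}  # loaded from the repo
passed_class_obj = {"controlnet": "user-controlnet"}        # passed to from_pretrained

missing_modules = expected_modules - set(init_kwargs)
if missing_modules and missing_modules <= set(passed_class_obj):
    # e.g. `controlnet` is not in the SD v1-5 repo, so it must come from the caller
    for module in missing_modules:
        init_kwargs[module] = passed_class_obj[module]
elif missing_modules:
    raise ValueError(f"missing pipeline modules: {missing_modules - optional_kwargs}")
```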
src/diffusers/pipelines/stable_diffusion/__init__.py

```diff
@@ -127,6 +127,7 @@ if is_transformers_available() and is_flax_available():
     from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
+    from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline
     from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
     from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
```
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py — new file (0 → 100644)

````python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import jax
        >>> import numpy as np
        >>> import jax.numpy as jnp
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> from diffusers.utils import load_image
        >>> from PIL import Image
        >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel


        >>> def image_grid(imgs, rows, cols):
        ...     w, h = imgs[0].size
        ...     grid = Image.new("RGB", size=(cols * w, rows * h))
        ...     for i, img in enumerate(imgs):
        ...         grid.paste(img, box=(i % cols * w, i // cols * h))
        ...     return grid


        >>> def create_key(seed=0):
        ...     return jax.random.PRNGKey(seed)


        >>> rng = create_key(0)

        >>> # get canny image
        >>> canny_image = load_image(
        ...     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
        ... )

        >>> prompts = "best quality, extremely detailed"
        >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"

        >>> # load control net and stable diffusion v1-5
        >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
        ... )
        >>> params["controlnet"] = controlnet_params

        >>> num_samples = jax.device_count()
        >>> rng = jax.random.split(rng, jax.device_count())

        >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
        >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
        >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

        >>> p_params = replicate(params)
        >>> prompt_ids = shard(prompt_ids)
        >>> negative_prompt_ids = shard(negative_prompt_ids)
        >>> processed_image = shard(processed_image)

        >>> output = pipe(
        ...     prompt_ids=prompt_ids,
        ...     image=processed_image,
        ...     params=p_params,
        ...     prng_seed=rng,
        ...     num_inference_steps=50,
        ...     neg_prompt_ids=negative_prompt_ids,
        ...     jit=True,
        ... ).images

        >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
        >>> output_images = image_grid(output_images, num_samples // 4, 4)
        >>> output_images.save("generated_image.png")
        ```
"""


class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
    the library implements for all the pipelines (such as downloading or saving, running on a particular device,
    etc.).

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        controlnet ([`FlaxControlNetModel`]):
            Provides additional conditioning to the `unet` during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        controlnet: FlaxControlNetModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warn(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )

        return text_input.input_ids

    def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        return processed_images

    def _get_has_nsfw_concepts(self, features, params):
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            if any(has_nsfw_concepts):
                warnings.warn(
                    "Potential NSFW content was detected in one or more images. A black image will be returned"
                    " instead. Try again with a different prompt and/or seed."
                )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int,
        guidance_scale: float,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        height, width = image.shape[-2:]
        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        image = jnp.concatenate([image] * 2)

        latents_shape = (
            batch_size,
            self.unet.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
                {"params": params["controlnet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                return_dict=False,
            )

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            ).sample

            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int = 50,
        guidance_scale: Union[float, jnp.array] = 7.5,
        latents: jnp.array = None,
        neg_prompt_ids: jnp.array = None,
        controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                The prompt or prompts to guide the image generation.
            image (`jnp.array`):
                Array representing the ControlNet input condition. ControlNet uses this input condition to generate
                guidance for the UNet.
            params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
            prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original UNet.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead
                of a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise
                a `tuple`. When returning a tuple, the first element is a list with the generated images, and the
                second element is a list of `bool`s denoting whether the corresponding generated image likely
                represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        height, width = image.shape[-2:]

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if isinstance(controlnet_conditioning_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
            )
        else:
            images = self._generate(
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = (images * 255).round().astype("uint8")
            num_devices, batch_size = images.shape[:2]

            images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
            images = np.asarray(images)

            # block images
            if any(has_nsfw_concept):
                for i, is_nsfw in enumerate(has_nsfw_concept):
                    if is_nsfw:
                        images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(num_devices, batch_size, height, width, 3)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
    static_broadcasted_argnums=(0, 5),
)
def _p_generate(
    pipe,
    prompt_ids,
    image,
    params,
    prng_seed,
    num_inference_steps,
    guidance_scale,
    latents,
    neg_prompt_ids,
    controlnet_conditioning_scale,
):
    return pipe._generate(
        prompt_ids,
        image,
        params,
        prng_seed,
        num_inference_steps,
        guidance_scale,
        latents,
        neg_prompt_ids,
        controlnet_conditioning_scale,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)


def preprocess(image, dtype):
    image = image.convert("RGB")
    w, h = image.size
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = jnp.array(image).astype(dtype) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return image
````
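Everything the pipeline pmaps follows one layout convention: `shard` adds a leading device axis for `jax.pmap`, and `unshard` above removes it again. A small sketch of that round trip (array contents are placeholders; `flat` reproduces the same reshape `unshard` performs):

```python
import jax
import numpy as np
from flax.training.common_utils import shard

# shard: (device_count * batch, ...) -> (device_count, batch, ...)
x = np.zeros((jax.device_count() * 2, 77), dtype=np.int32)
sharded = shard(x)
assert sharded.shape == (jax.device_count(), 2, 77)

# the inverse: collapse the device axis back into the batch axis
num_devices, batch_size = sharded.shape[:2]
flat = sharded.reshape(num_devices * batch_size, *sharded.shape[2:])
assert flat.shape == x.shape
```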
src/diffusers/utils/dummy_flax_and_transformers_objects.py

```diff
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends


+class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
+    _backends = ["flax", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax", "transformers"])
+
+
 class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["flax", "transformers"]
```
src/diffusers/utils/dummy_flax_objects.py

```diff
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends


+class FlaxControlNetModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxModelMixin(metaclass=DummyObject):
     _backends = ["flax"]
```
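These placeholder classes keep `from diffusers import FlaxControlNetModel` working even when Flax is absent; only *using* the name raises. A simplified sketch of the mechanism (this `DummyObject`/`requires_backends` pair is a stripped-down stand-in for the real ones in `diffusers.utils`):

```python
# Simplified stand-ins for diffusers' DummyObject/requires_backends, for illustration.
class DummyObject(type):
    def __getattr__(cls, name):
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    raise ImportError(f"{obj} requires the following backends: {backends}")


class FlaxControlNetModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


# Importing the name succeeds; instantiating it fails with a clear message.
try:
    FlaxControlNetModel()
except ImportError as e:
    print(e)
```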
tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py — new file (0 → 100644)

```python
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import is_flax_available, load_image, slow
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard


@slow
@require_flax
class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def test_canny(self):
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
        )
        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
        )
        params["controlnet"] = controlnet_params

        prompts = "bird"
        num_samples = jax.device_count()
        prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)

        canny_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        )
        processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

        rng = jax.random.PRNGKey(0)
        rng = jax.random.split(rng, jax.device_count())

        p_params = replicate(params)
        prompt_ids = shard(prompt_ids)
        processed_image = shard(processed_image)

        images = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=rng,
            num_inference_steps=50,
            jit=True,
        ).images
        assert images.shape == (jax.device_count(), 1, 768, 512, 3)

        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        image_slice = images[0, 253:256, 253:256, -1]

        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
        expected_slice = jnp.array(
            [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
        )
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2

    def test_pose(self):
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
        )
        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
        )
        params["controlnet"] = controlnet_params

        prompts = "Chef in the kitchen"
        num_samples = jax.device_count()
        prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)

        pose_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
        )
        processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)

        rng = jax.random.PRNGKey(0)
        rng = jax.random.split(rng, jax.device_count())

        p_params = replicate(params)
        prompt_ids = shard(prompt_ids)
        processed_image = shard(processed_image)

        images = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=rng,
            num_inference_steps=50,
            jit=True,
        ).images
        assert images.shape == (jax.device_count(), 1, 768, 512, 3)

        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        image_slice = images[0, 253:256, 253:256, -1]

        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
        expected_slice = jnp.array(
            [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
        )
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
```