Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
df91c447
Unverified
Commit
df91c447
authored
Mar 23, 2023
by
YiYi Xu
Committed by
GitHub
Mar 23, 2023
Browse files
Flax controlnet (#2727)
* add controlnet flax --------- Co-authored-by:
yiyixuxu
<
yixu310@gmail,com
>
parent
aa0531fa
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1125 additions
and
2 deletions
+1125
-2
docs/source/en/api/models.mdx
docs/source/en/api/models.mdx
+6
-0
docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
+6
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+2
-0
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-0
src/diffusers/models/controlnet_flax.py
src/diffusers/models/controlnet_flax.py
+383
-0
src/diffusers/models/unet_2d_condition_flax.py
src/diffusers/models/unet_2d_condition_flax.py
+16
-0
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-0
src/diffusers/pipelines/pipeline_flax_utils.py
src/diffusers/pipelines/pipeline_flax_utils.py
+15
-2
src/diffusers/pipelines/stable_diffusion/__init__.py
src/diffusers/pipelines/stable_diffusion/__init__.py
+1
-0
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
...le_diffusion/pipeline_flax_stable_diffusion_controlnet.py
+537
-0
src/diffusers/utils/dummy_flax_and_transformers_objects.py
src/diffusers/utils/dummy_flax_and_transformers_objects.py
+15
-0
src/diffusers/utils/dummy_flax_objects.py
src/diffusers/utils/dummy_flax_objects.py
+15
-0
tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py
...stable_diffusion/test_stable_diffusion_flax_controlnet.py
+127
-0
No files found.
docs/source/en/api/models.mdx
View file @
df91c447
...
@@ -99,3 +99,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
...
@@ -99,3 +99,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## FlaxAutoencoderKL
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
## FlaxControlNetOutput
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
## FlaxControlNetModel
[[autodoc]] FlaxControlNetModel
docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx
View file @
df91c447
...
@@ -272,3 +272,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
...
@@ -272,3 +272,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
- disable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## FlaxStableDiffusionControlNetPipeline
[[autodoc]] FlaxStableDiffusionControlNetPipeline
- all
- __call__
src/diffusers/__init__.py
View file @
df91c447
...
@@ -188,6 +188,7 @@ try:
...
@@ -188,6 +188,7 @@ try:
except
OptionalDependencyNotAvailable
:
except
OptionalDependencyNotAvailable
:
from
.utils.dummy_flax_objects
import
*
# noqa F403
from
.utils.dummy_flax_objects
import
*
# noqa F403
else
:
else
:
from
.models.controlnet_flax
import
FlaxControlNetModel
from
.models.modeling_flax_utils
import
FlaxModelMixin
from
.models.modeling_flax_utils
import
FlaxModelMixin
from
.models.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.models.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.models.vae_flax
import
FlaxAutoencoderKL
from
.models.vae_flax
import
FlaxAutoencoderKL
...
@@ -211,6 +212,7 @@ except OptionalDependencyNotAvailable:
...
@@ -211,6 +212,7 @@ except OptionalDependencyNotAvailable:
from
.utils.dummy_flax_and_transformers_objects
import
*
# noqa F403
from
.utils.dummy_flax_and_transformers_objects
import
*
# noqa F403
else
:
else
:
from
.pipelines
import
(
from
.pipelines
import
(
FlaxStableDiffusionControlNetPipeline
,
FlaxStableDiffusionImg2ImgPipeline
,
FlaxStableDiffusionImg2ImgPipeline
,
FlaxStableDiffusionInpaintPipeline
,
FlaxStableDiffusionInpaintPipeline
,
FlaxStableDiffusionPipeline
,
FlaxStableDiffusionPipeline
,
...
...
src/diffusers/models/__init__.py
View file @
df91c447
...
@@ -30,5 +30,6 @@ if is_torch_available():
...
@@ -30,5 +30,6 @@ if is_torch_available():
from
.vq_model
import
VQModel
from
.vq_model
import
VQModel
if
is_flax_available
():
if
is_flax_available
():
from
.controlnet_flax
import
FlaxControlNetModel
from
.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.unet_2d_condition_flax
import
FlaxUNet2DConditionModel
from
.vae_flax
import
FlaxAutoencoderKL
from
.vae_flax
import
FlaxAutoencoderKL
src/diffusers/models/controlnet_flax.py
0 → 100644
View file @
df91c447
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Tuple
,
Union
import
flax
import
flax.linen
as
nn
import
jax
import
jax.numpy
as
jnp
from
flax.core.frozen_dict
import
FrozenDict
from
..configuration_utils
import
ConfigMixin
,
flax_register_to_config
from
..utils
import
BaseOutput
from
.embeddings_flax
import
FlaxTimestepEmbedding
,
FlaxTimesteps
from
.modeling_flax_utils
import
FlaxModelMixin
from
.unet_2d_blocks_flax
import
(
FlaxCrossAttnDownBlock2D
,
FlaxDownBlock2D
,
FlaxUNetMidBlock2DCrossAttn
,
)
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    """
    Output container returned by `FlaxControlNetModel.__call__` when `return_dict=True`.

    Args:
        down_block_res_samples:
            Residuals derived from the activations collected before and inside the down blocks, each passed through
            its zero-initialised 1x1 projection and multiplied by `conditioning_scale`.
        mid_block_res_sample:
            Residual derived from the mid-block output, passed through its own 1x1 projection and multiplied by
            `conditioning_scale`.
    """

    # NOTE: annotated as a single array, but `FlaxControlNetModel.__call__` stores a Python list of arrays here
    # (one entry per ControlNet projection).
    down_block_res_samples: jnp.ndarray
    # Single array produced from the mid block.
    mid_block_res_sample: jnp.ndarray
class FlaxControlNetConditioningEmbedding(nn.Module):
    """
    Convolutional encoder applied to the ControlNet conditioning input.

    The stack is `conv_in`, then for every consecutive pair of `block_out_channels` one 3x3 convolution followed by
    one stride-2 3x3 convolution, and finally a zero-initialised `conv_out`. SiLU is applied after every convolution
    except `conv_out`.

    Attributes:
        conditioning_embedding_channels: number of output features of `conv_out`.
        block_out_channels: feature widths of the intermediate convolutions.
        dtype: dtype handed to every convolution.
    """

    conditioning_embedding_channels: int
    block_out_channels: Tuple[int] = (16, 32, 96, 256)
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        widths = self.block_out_channels
        # Explicit one-pixel padding on both spatial axes, shared by all 3x3 convolutions below.
        pad_one = ((1, 1), (1, 1))

        self.conv_in = nn.Conv(
            widths[0],
            kernel_size=(3, 3),
            padding=pad_one,
            dtype=self.dtype,
        )

        stages = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            # First convolution of the pair keeps the current width ...
            stages.append(
                nn.Conv(
                    width_in,
                    kernel_size=(3, 3),
                    padding=pad_one,
                    dtype=self.dtype,
                )
            )
            # ... the second one moves to the next width and uses stride 2.
            stages.append(
                nn.Conv(
                    width_out,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    padding=pad_one,
                    dtype=self.dtype,
                )
            )
        self.blocks = stages

        # Kernel and bias start at zero, so the embedding is all zeros right after initialisation.
        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=pad_one,
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning):
        """Run `conditioning` through the convolution stack and return the resulting embedding."""
        hidden = nn.silu(self.conv_in(conditioning))
        for layer in self.blocks:
            hidden = nn.silu(layer(hidden))
        return self.conv_out(hidden)
@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax ControlNet: an input convolution, time embedding, down blocks and mid block, where every collected
    activation is passed through a zero-initialised 1x1 convolution and returned as a residual.

    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*, defaults to 32):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
            "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Forwarded to each `FlaxCrossAttnDownBlock2D`; a single `bool` is repeated for every down block.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down and bottleneck blocks.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded to the cross-attention down blocks and to the mid block.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The dtype handed to the sub-modules.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
            The channel order of conditional image. Will convert it to `rgb` if it's `bgr`
        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in conditioning_embedding layer
    """

    sample_size: int = 32
    in_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    controlnet_conditioning_channel_order: str = "rgb"
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

    # The annotation is kept as a string so it is never evaluated: `jax.random.KeyArray` does not exist in
    # recent JAX releases and an eagerly evaluated annotation would fail when the class is defined.
    def init_weights(self, rng: "jax.random.KeyArray") -> FrozenDict:
        """
        Initialise the model parameters by tracing the module once with dummy inputs.

        Args:
            rng: PRNG key; it is split into one key for `"params"` and one for `"dropout"`.

        Returns:
            The `"params"` collection produced by `self.init`.
        """
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        # The dummy conditioning image is 8x larger than the sample along each spatial axis.
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

    def setup(self):
        """Build the input convolution, time embedding, conditioning embedding, down/mid blocks and projections."""
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        def zero_conv(features):
            # 1x1 projection whose kernel and bias start at zero, so its output is zero right after init.
            return nn.Conv(
                features,
                kernel_size=(1, 1),
                padding="VALID",
                kernel_init=nn.initializers.zeros_init(),
                bias_init=nn.initializers.zeros_init(),
                dtype=self.dtype,
            )

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        # Propagate `dtype` like every other dtype-aware sub-module; previously the conditioning embedding
        # always ran with its float32 default regardless of the model dtype.
        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=self.conditioning_embedding_out_channels,
            dtype=self.dtype,
        )

        # Broadcast scalar settings to one entry per down block.
        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        attention_head_dim = self.attention_head_dim
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)

        # down
        down_blocks = []
        controlnet_down_blocks = []

        output_channel = block_out_channels[0]

        # Projection for the activation that comes straight out of `conv_in`.
        controlnet_down_blocks.append(zero_conv(output_channel))

        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    attn_num_head_channels=attention_head_dim[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)

            # One projection per layer of the block ...
            for _ in range(self.layers_per_block):
                controlnet_down_blocks.append(zero_conv(output_channel))

            # ... plus one more for every block that is built with a downsampler.
            if not is_final_block:
                controlnet_down_blocks.append(zero_conv(output_channel))

        self.down_blocks = down_blocks
        self.controlnet_down_blocks = controlnet_down_blocks

        # mid
        mid_block_channel = block_out_channels[-1]
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
            attn_num_head_channels=attention_head_dim[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )

        self.controlnet_mid_block = zero_conv(mid_block_channel)

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale: float = 1.0,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxControlNetOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
            conditioning_scale: (`float`) the scale factor for controlnet outputs
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.controlnet_flax.FlaxControlNetOutput`] instead of a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.controlnet_flax.FlaxControlNetOutput`] or `tuple`:
            [`~models.controlnet_flax.FlaxControlNetOutput`] if `return_dict` is True, otherwise the `tuple`
            `(down_block_res_samples, mid_block_res_sample)`. `down_block_res_samples` is a list of arrays. The
            inputs are transposed to channels-last before the convolutions and the residuals are not transposed
            back.
        """
        channel_order = self.controlnet_conditioning_channel_order
        if channel_order == "bgr":
            # Reverse the channel axis of the (batch, channel, height, width) conditioning input.
            controlnet_cond = jnp.flip(controlnet_cond, axis=1)

        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            # Python scalars become a one-element int32 array.
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif len(timesteps.shape) == 0:
            # 0-d arrays are cast to float32 and given a leading axis.
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        # (batch, channel, height, width) -> (batch, height, width, channel)
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        # 5. controlnet blocks
        projected_res_samples = [
            controlnet_block(res_sample)
            for res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks)
        ]
        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [res_sample * conditioning_scale for res_sample in projected_res_samples]
        mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return FlaxControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
src/diffusers/models/unet_2d_condition_flax.py
View file @
df91c447
...
@@ -249,6 +249,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -249,6 +249,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample
,
sample
,
timesteps
,
timesteps
,
encoder_hidden_states
,
encoder_hidden_states
,
down_block_additional_residuals
=
None
,
mid_block_additional_residual
=
None
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
train
:
bool
=
False
,
train
:
bool
=
False
,
)
->
Union
[
FlaxUNet2DConditionOutput
,
Tuple
]:
)
->
Union
[
FlaxUNet2DConditionOutput
,
Tuple
]:
...
@@ -291,9 +293,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
...
@@ -291,9 +293,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample
,
res_samples
=
down_block
(
sample
,
t_emb
,
deterministic
=
not
train
)
sample
,
res_samples
=
down_block
(
sample
,
t_emb
,
deterministic
=
not
train
)
down_block_res_samples
+=
res_samples
down_block_res_samples
+=
res_samples
if
down_block_additional_residuals
is
not
None
:
new_down_block_res_samples
=
()
for
down_block_res_sample
,
down_block_additional_residual
in
zip
(
down_block_res_samples
,
down_block_additional_residuals
):
down_block_res_sample
+=
down_block_additional_residual
new_down_block_res_samples
+=
(
down_block_res_sample
,)
down_block_res_samples
=
new_down_block_res_samples
# 4. mid
# 4. mid
sample
=
self
.
mid_block
(
sample
,
t_emb
,
encoder_hidden_states
,
deterministic
=
not
train
)
sample
=
self
.
mid_block
(
sample
,
t_emb
,
encoder_hidden_states
,
deterministic
=
not
train
)
if
mid_block_additional_residual
is
not
None
:
sample
+=
mid_block_additional_residual
# 5. up
# 5. up
for
up_block
in
self
.
up_blocks
:
for
up_block
in
self
.
up_blocks
:
res_samples
=
down_block_res_samples
[
-
(
self
.
layers_per_block
+
1
)
:]
res_samples
=
down_block_res_samples
[
-
(
self
.
layers_per_block
+
1
)
:]
...
...
src/diffusers/pipelines/__init__.py
View file @
df91c447
...
@@ -124,6 +124,7 @@ except OptionalDependencyNotAvailable:
...
@@ -124,6 +124,7 @@ except OptionalDependencyNotAvailable:
from
..utils.dummy_flax_and_transformers_objects
import
*
# noqa F403
from
..utils.dummy_flax_and_transformers_objects
import
*
# noqa F403
else
:
else
:
from
.stable_diffusion
import
(
from
.stable_diffusion
import
(
FlaxStableDiffusionControlNetPipeline
,
FlaxStableDiffusionImg2ImgPipeline
,
FlaxStableDiffusionImg2ImgPipeline
,
FlaxStableDiffusionInpaintPipeline
,
FlaxStableDiffusionInpaintPipeline
,
FlaxStableDiffusionPipeline
,
FlaxStableDiffusionPipeline
,
...
...
src/diffusers/pipelines/pipeline_flax_utils.py
View file @
df91c447
...
@@ -278,7 +278,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -278,7 +278,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>>
sched, sched
_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
>>>
dpmpp, dpmpp
_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id,
... model_id,
... subfolder="scheduler",
... subfolder="scheduler",
... )
... )
...
@@ -365,7 +365,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -365,7 +365,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# some modules can be passed directly to the init
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# in this case they are already instantiated in `kwargs`
# extract them here
# extract them here
expected_modules
=
set
(
inspect
.
signature
(
pipeline_class
.
__init__
).
parameters
.
keys
()
)
expected_modules
,
optional_kwargs
=
cls
.
_get_
signature
_keys
(
pipeline_class
)
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
passed_class_obj
=
{
k
:
kwargs
.
pop
(
k
)
for
k
in
expected_modules
if
k
in
kwargs
}
init_dict
,
_
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
_
,
_
=
pipeline_class
.
extract_init_dict
(
config_dict
,
**
kwargs
)
...
@@ -470,6 +470,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -470,6 +470,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
# 4. Potentially add passed objects if expected
missing_modules
=
set
(
expected_modules
)
-
set
(
init_kwargs
.
keys
())
passed_modules
=
list
(
passed_class_obj
.
keys
())
if
len
(
missing_modules
)
>
0
and
missing_modules
<=
set
(
passed_modules
):
for
module
in
missing_modules
:
init_kwargs
[
module
]
=
passed_class_obj
.
get
(
module
,
None
)
elif
len
(
missing_modules
)
>
0
:
passed_modules
=
set
(
list
(
init_kwargs
.
keys
())
+
list
(
passed_class_obj
.
keys
()))
-
optional_kwargs
raise
ValueError
(
f
"Pipeline
{
pipeline_class
}
expected
{
expected_modules
}
, but only
{
passed_modules
}
were passed."
)
model
=
pipeline_class
(
**
init_kwargs
,
dtype
=
dtype
)
model
=
pipeline_class
(
**
init_kwargs
,
dtype
=
dtype
)
return
model
,
params
return
model
,
params
...
...
src/diffusers/pipelines/stable_diffusion/__init__.py
View file @
df91c447
...
@@ -127,6 +127,7 @@ if is_transformers_available() and is_flax_available():
...
@@ -127,6 +127,7 @@ if is_transformers_available() and is_flax_available():
from
...schedulers.scheduling_pndm_flax
import
PNDMSchedulerState
from
...schedulers.scheduling_pndm_flax
import
PNDMSchedulerState
from
.pipeline_flax_stable_diffusion
import
FlaxStableDiffusionPipeline
from
.pipeline_flax_stable_diffusion
import
FlaxStableDiffusionPipeline
from
.pipeline_flax_stable_diffusion_controlnet
import
FlaxStableDiffusionControlNetPipeline
from
.pipeline_flax_stable_diffusion_img2img
import
FlaxStableDiffusionImg2ImgPipeline
from
.pipeline_flax_stable_diffusion_img2img
import
FlaxStableDiffusionImg2ImgPipeline
from
.pipeline_flax_stable_diffusion_inpaint
import
FlaxStableDiffusionInpaintPipeline
from
.pipeline_flax_stable_diffusion_inpaint
import
FlaxStableDiffusionInpaintPipeline
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
0 → 100644
View file @
df91c447
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
functools
import
partial
from
typing
import
Dict
,
List
,
Optional
,
Union
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
from
flax.core.frozen_dict
import
FrozenDict
from
flax.jax_utils
import
unreplicate
from
flax.training.common_utils
import
shard
from
PIL
import
Image
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
,
FlaxCLIPTextModel
from
...models
import
FlaxAutoencoderKL
,
FlaxControlNetModel
,
FlaxUNet2DConditionModel
from
...schedulers
import
(
FlaxDDIMScheduler
,
FlaxDPMSolverMultistepScheduler
,
FlaxLMSDiscreteScheduler
,
FlaxPNDMScheduler
,
)
from
...utils
import
PIL_INTERPOLATION
,
logging
,
replace_example_docstring
from
..pipeline_flax_utils
import
FlaxDiffusionPipeline
from
.
import
FlaxStableDiffusionPipelineOutput
from
.safety_checker_flax
import
FlaxStableDiffusionSafetyChecker
# Module-level logger obtained from the library's `logging` helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

# Usage example handed to the `replace_example_docstring` decorator on
# `FlaxStableDiffusionControlNetPipeline.__call__`.
# NOTE(review): `image_grid(output_images, num_samples // 4, 4)` in the example yields zero
# rows when fewer than 4 devices are available - confirm whether that is intended.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import jax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> from diffusers.utils import load_image
>>> from PIL import Image
>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
>>> def image_grid(imgs, rows, cols):
... w, h = imgs[0].size
... grid = Image.new("RGB", size=(cols * w, rows * h))
... for i, img in enumerate(imgs):
... grid.paste(img, box=(i % cols * w, i // cols * h))
... return grid
>>> def create_key(seed=0):
... return jax.random.PRNGKey(seed)
>>> rng = create_key(0)
>>> # get canny image
>>> canny_image = load_image(
... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
... )
>>> prompts = "best quality, extremely detailed"
>>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"
>>> # load control net and stable diffusion v1-5
>>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
... )
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
... )
>>> params["controlnet"] = controlnet_params
>>> num_samples = jax.device_count()
>>> rng = jax.random.split(rng, jax.device_count())
>>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
>>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
>>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
>>> p_params = replicate(params)
>>> prompt_ids = shard(prompt_ids)
>>> negative_prompt_ids = shard(negative_prompt_ids)
>>> processed_image = shard(processed_image)
>>> output = pipe(
... prompt_ids=prompt_ids,
... image=processed_image,
... params=p_params,
... prng_seed=rng,
... num_inference_steps=50,
... neg_prompt_ids=negative_prompt_ids,
... jit=True,
... ).images
>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
>>> output_images = image_grid(output_images, num_samples // 4, 4)
>>> output_images.save("generated_image.png")
```
"""
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        controlnet ([`FlaxControlNetModel`]):
            Provides additional conditioning to the unet during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            Stored on the pipeline as `self.dtype`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        controlnet: FlaxControlNetModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Ratio between image resolution and latent resolution, derived from the VAE config.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
        """
        Tokenize `prompt` and return the resulting `input_ids`.

        Prompts are padded/truncated to `self.tokenizer.model_max_length` and returned as NumPy tensors
        (`return_tensors="np"`).

        Raises:
            ValueError: If `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )

        return text_input.input_ids

    def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
        """
        Convert one PIL image (or a list of them) into a single conditioning array.

        Each image is passed through the module-level `preprocess` helper with `jnp.float32` and the results are
        concatenated along the first axis.

        Raises:
            ValueError: If `image` is neither a `PIL.Image.Image` nor a `list`.
        """
        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        return processed_images

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on `features` with `params` and return its result."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Score `images` with the safety checker and black out the flagged ones.

        Args:
            images: Sequence of `uint8` image arrays indexable by a flat image index.
            safety_model_params: Safety checker parameters; they should already be replicated when `jit` is True.
            jit: When True, features are sharded and scored through the pmapped `_p_get_has_nsfw_concepts`.

        Returns:
            `(images, has_nsfw_concepts)`. `images` is only copied if at least one image was flagged.
        """
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            # NOTE(review): the unreplicated params are not used after this point.
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                # Copy lazily so the caller's array is never mutated.
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            if any(has_nsfw_concepts):
                warnings.warn(
                    "Potential NSFW content was detected in one or more images. A black image will be returned"
                    " instead. Try again with a different prompt and/or seed."
                )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int,
        guidance_scale: float,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        """
        Run the denoising loop and decode the final latents with the VAE.

        The spatial size is read from the last two axes of `image`. Classifier-free guidance is always applied: the
        unconditional and text embeddings (and the conditioning image) are stacked into one batch of twice the size.

        Returns:
            The decoded images, rescaled with `image / 2 + 0.5`, clipped to `[0, 1]` and transposed with
            `(0, 2, 3, 1)`.

        Raises:
            ValueError: If the image height/width are not divisible by 64, or if `latents` is given with a shape
                other than the expected latents shape.
        """
        height, width = image.shape[-2:]
        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            # No negative prompt supplied: use empty prompts for the unconditional branch.
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # Duplicate the conditioning image so it lines up with the doubled (uncond + text) batch.
        image = jnp.concatenate([image] * 2)

        latents_shape = (
            batch_size,
            self.unet.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # ControlNet produces the residuals that are fed into the UNet below.
            down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
                {"params": params["controlnet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                return_dict=False,
            )

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            ).sample

            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int = 50,
        guidance_scale: Union[float, jnp.array] = 7.5,
        latents: jnp.array = None,
        neg_prompt_ids: jnp.array = None,
        controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                The prompt or prompts to guide the image generation.
            image (`jnp.array`):
                Array representing the ControlNet input condition. ControlNet use this input condition to generate
                guidance to Unet.
            params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
            prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied `prng_seed`.
            neg_prompt_ids (`jnp.array`, *optional*):
                Token ids used for the unconditional branch of classifier-free guidance. If not provided, empty
                prompts are tokenized and used instead.
            controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        height, width = image.shape[-2:]

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if isinstance(controlnet_conditioning_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
            )
        else:
            images = self._generate(
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = (images * 255).round().astype("uint8")
            num_devices, batch_size = images.shape[:2]
            flat_shape = (num_devices * batch_size, height, width, 3)

            images_uint8_casted = np.asarray(images_uint8_casted).reshape(flat_shape)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # `np.array` makes a writable copy (a NumPy view of a JAX array is read-only), and the
            # copy is flattened so the flat index `i` from the safety checker addresses one image.
            images = np.array(images).reshape(flat_shape)

            # block images
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(num_devices, batch_size, height, width, 3)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
# `pipe` (arg 0) and `num_inference_steps` (arg 5) are static/broadcast arguments, so changing
# either one triggers a recompilation. Every other argument is a (sharded) input mapped over
# its leading axis, hence the `0` entries in `in_axes`.
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
    static_broadcasted_argnums=(0, 5),
)
def _p_generate(
    pipe,
    prompt_ids,
    image,
    params,
    prng_seed,
    num_inference_steps,
    guidance_scale,
    latents,
    neg_prompt_ids,
    controlnet_conditioning_scale,
):
    """pmapped entry point that forwards all arguments unchanged to `pipe._generate`."""
    return pipe._generate(
        prompt_ids=prompt_ids,
        image=image,
        params=params,
        prng_seed=prng_seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        latents=latents,
        neg_prompt_ids=neg_prompt_ids,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
    )
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """pmapped wrapper around `pipe._get_has_nsfw_concepts`; `pipe` is the static argument."""
    nsfw_flags = pipe._get_has_nsfw_concepts(features, params)
    return nsfw_flags
def unshard(x: jnp.ndarray):
    """Merge the first two axes of `x` into one, keeping all trailing axes as they are."""
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    devices, per_device, *trailing = x.shape
    return x.reshape((devices * per_device, *trailing))
def preprocess(image, dtype):
    """
    Turn a PIL image into a batched, channel-first array scaled by 1/255.

    The image is converted to RGB, each side is rounded down to a multiple of 64, and the image is
    resized with the "lanczos" entry of `PIL_INTERPOLATION`. The pixels are cast to `dtype`, divided
    by 255.0, given a leading batch axis and transposed with `(0, 3, 1, 2)`.
    """
    rgb = image.convert("RGB")
    # resize to integer multiple of 64
    width, height = (side - side % 64 for side in rgb.size)
    rgb = rgb.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
    pixels = jnp.array(rgb).astype(dtype) / 255.0
    batched = pixels[None]
    return batched.transpose(0, 3, 1, 2)
src/diffusers/utils/dummy_flax_and_transformers_objects.py
View file @
df91c447
...
@@ -2,6 +2,21 @@
...
@@ -2,6 +2,21 @@
from
..utils
import
DummyObject
,
requires_backends
from
..utils
import
DummyObject
,
requires_backends
# Stand-in for `FlaxStableDiffusionControlNetPipeline`: the constructor and both classmethods
# only call `requires_backends` with the "flax" and "transformers" backends.
# NOTE(review): presumably `requires_backends` raises when a backend is missing - confirm in `..utils`.
class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["flax", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax", "transformers"])
class
FlaxStableDiffusionImg2ImgPipeline
(
metaclass
=
DummyObject
):
class
FlaxStableDiffusionImg2ImgPipeline
(
metaclass
=
DummyObject
):
_backends
=
[
"flax"
,
"transformers"
]
_backends
=
[
"flax"
,
"transformers"
]
...
...
src/diffusers/utils/dummy_flax_objects.py
View file @
df91c447
...
@@ -2,6 +2,21 @@
...
@@ -2,6 +2,21 @@
from
..utils
import
DummyObject
,
requires_backends
from
..utils
import
DummyObject
,
requires_backends
# Stand-in for `FlaxControlNetModel`: the constructor and both classmethods only call
# `requires_backends` with the "flax" backend.
# NOTE(review): presumably `requires_backends` raises when the backend is missing - confirm in `..utils`.
class FlaxControlNetModel(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
class
FlaxModelMixin
(
metaclass
=
DummyObject
):
class
FlaxModelMixin
(
metaclass
=
DummyObject
):
_backends
=
[
"flax"
]
_backends
=
[
"flax"
]
...
...
tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py
0 → 100644
View file @
df91c447
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
gc
import
unittest
from
diffusers
import
FlaxControlNetModel
,
FlaxStableDiffusionControlNetPipeline
from
diffusers.utils
import
is_flax_available
,
load_image
,
slow
from
diffusers.utils.testing_utils
import
require_flax
if
is_flax_available
():
import
jax
import
jax.numpy
as
jnp
from
flax.jax_utils
import
replicate
from
flax.training.common_utils
import
shard
@slow
@require_flax
class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase):
    # Slow integration tests: each one loads a ControlNet checkpoint plus Stable Diffusion v1-5,
    # runs 50 pmapped inference steps and compares a 3x3 patch of the first image with
    # reference values.

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()

    def _run_and_slice(self, controlnet_repo, prompt, condition_url):
        # Shared driver for both tests: build the pipeline, generate one image per device and
        # return the flattened [253:256, 253:256] patch of the last channel of the first image.
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            controlnet_repo, from_pt=True, dtype=jnp.bfloat16
        )
        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
        )
        params["controlnet"] = controlnet_params

        num_samples = jax.device_count()
        prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)

        condition_image = load_image(condition_url)
        processed_image = pipe.prepare_image_inputs([condition_image] * num_samples)

        rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

        images = pipe(
            prompt_ids=shard(prompt_ids),
            image=shard(processed_image),
            params=replicate(params),
            prng_seed=rng,
            num_inference_steps=50,
            jit=True,
        ).images
        assert images.shape == (jax.device_count(), 1, 768, 512, 3)

        # Collapse the (device, batch) axes before slicing the first image.
        flat_images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        patch = flat_images[0, 253:256, 253:256, -1]
        return jnp.asarray(jax.device_get(patch.flatten()))

    def test_canny(self):
        output_slice = self._run_and_slice(
            "lllyasviel/sd-controlnet-canny",
            "bird",
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png",
        )

        expected_slice = jnp.array(
            [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
        )
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2

    def test_pose(self):
        output_slice = self._run_and_slice(
            "lllyasviel/sd-controlnet-openpose",
            "Chef in the kitchen",
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png",
        )

        expected_slice = jnp.array(
            [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
        )
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment