Unverified Commit 1051ca81 authored by YiYi Xu, committed by GitHub

Stable Diffusion Latent Upscaler (#2059)



* Modify UNet2DConditionModel

- allow skipping the mid block (`mid_block_type=None`)

- add a `norm_group_size` argument so that `num_groups` for group norm can be set as `num_channels // norm_group_size`

- allow the user to set the dimension of the timestep embedding (`time_embed_dim`)

- the kernel_size for `conv_in` and `conv_out` is now configurable

- add a random Fourier feature layer (`GaussianFourierProjection`) option for `time_proj`

- allow the user to add the time and class embeddings together before passing them through the projection layer: `time_embedding(t_emb + class_label)`

- added 2 arguments, `attn1_types` and `attn2_types`

  * currently we have the argument `only_cross_attention`: when it is set to `True`, the `BasicTransformerBlock` uses two cross-attention layers, otherwise we
get a self-attention followed by a cross-attention; in the k-upscaler, we need blocks that include either just one cross-attention, or self-attention -> cross-attention;
so I added `attn1_types` and `attn2_types` to the unet's argument list to allow the user to specify the attention types for the 2 positions in each block; note that I still kept
the `only_cross_attention` argument for the unet for easy configuration, but it will be converted to `attn1_type` and `attn2_type` when passed down to the down blocks

- the position of the downsample and upsample layers is now configurable

- in the k-upscaler unet, there is only one skip connection per up/down block (instead of one per layer as in the stable diffusion unet); added `skip_freq = "block"` to support
this use case

- if the user passes an attention_mask to the unet, it will prepare the mask and pass a flag to the cross attention processor to skip the `prepare_attention_mask` step
inside the cross attention block

add up/down blocks for the k-upscaler

modify the CrossAttention class

- make the `dropout` layer in `to_out` optional

- `use_conv_proj` - use conv instead of linear for all projection layers (i.e. `to_q`, `to_k`, `to_v`, `to_out`) whenever possible. Note that when it is used for cross
attention, `to_k` and `to_v` have to be linear because the `encoder_hidden_states` is not 2d

- `cross_attention_norm` - add an optional layernorm on `encoder_hidden_states`

- `attention_dropout`: add an optional dropout on the attention scores

adapt BasicTransformerBlock

- add an ada group norm layer to condition the attention input on the timestep embedding

- allow skipping the FeedForward layer in between the attentions

- replaced the `only_cross_attention` argument with `attn1_type` and `attn2_type` for more flexible configuration

update the timestep embedding: add a new act_fn, `gelu`, and an optional `act_2`

modified ResnetBlock2D

- refactored with the `AdaGroupNorm` class (the timestep scale-shift normalization)

- add a `mid_channel` argument - allow the first conv to have a different output dimension from the second conv

- add an option to use `AdaGroupNorm` on the input instead of group norm

- add an option to add a dropout layer after each conv

- allow the user to set the bias in `conv_shortcut` (needed for the k-upscaler)

- add gelu

adding a conversion script for the k-upscaler unet (see the configuration sketch right after this list)

add pipeline
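For reference, the new UNet arguments come together in the conversion script further down; the following is a condensed, illustrative sketch of such a configuration (the concrete numbers are assumptions, only the argument names and block types come from this PR):

```python
from diffusers import UNet2DConditionModel

# illustrative k-upscaler style config exercising the new arguments
unet = UNet2DConditionModel(
    in_channels=8,                 # assumed: image latents + low-res conditioning latents
    out_channels=5,                # assumed: latent channels (+1 if a variance channel is predicted)
    down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
    up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
    mid_block_type=None,           # new: skip the mid block entirely
    block_out_channels=(128, 256, 512, 512),    # assumed widths
    layers_per_block=2,            # assumed depth
    act_fn="gelu",
    norm_num_groups=None,          # new: skip the final group norm / activation
    cross_attention_dim=768,       # assumed text-embedding width
    attention_head_dim=64,
    time_embedding_type="fourier",        # new: GaussianFourierProjection for time_proj
    timestep_post_act="gelu",             # new: second activation in the timestep embedding
    time_cond_proj_dim=896,               # assumed: extra conditioning added to the timestep embedding
    resnet_time_scale_shift="scale_shift",
    conv_in_kernel=1,                     # new: configurable conv_in / conv_out kernel size
    conv_out_kernel=1,
)
```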

* fix attention mask

* fix a typo

* fix a bug

* make sure model can be used with GPU

* make pipeline work with fp16

* fix an error in BasicTransformerBlock

* make style

* fix typo

* some more fixes

* uP

* up

* correct more

* some clean-up

* clean time proj

* up

* uP

* more changes

* remove the upcast_attention=True from unet config

* remove attn1_types, attn2_types etc

* fix

* revert incorrect changes up/down samplers

* make style

* remove outdated files

* Apply suggestions from code review

* attention refactor

* refactor cross attention

* Apply suggestions from code review

* update

* up

* update

* Apply suggestions from code review

* finish

* Update src/diffusers/models/cross_attention.py

* more fixes

* up

* up

* up

* finish

* more corrections of conversion state

* act_2 -> act_2_fn

* remove dropout_after_conv from ResnetBlock2D

* make style

* simplify KAttentionBlock

* add fast test for latent upscaler pipeline

* add slow test

* slow test fp16

* make style

* add doc string for pipeline_stable_diffusion_latent_upscale

* add api doc page for latent upscaler pipeline

* deprecate attention mask

* clean up embeddings

* simplify resnet

* up

* clean up resnet

* up

* correct more

* up

* up

* improve a bit more

* correct more

* more clean-ups

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add docstrings for new unet config

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* # Copied from

* encode the image if not latent

* remove force casting vae to fp32

* fix

* add comments about preconditioning parameters from k-diffusion paper

* attn1_type, attn2_type -> add_self_attention

* clean up get_down_block and get_up_block

* fix

* fixed a typo(?) in ada group norm

* update slice attention processor for cross attention

* update slice

* fix fast test

* update the checkpoint

* finish tests

* fix-copies

* fix-copy for modeling_text_unet.py

* make style

* make style

* fix f-string

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix import

* correct changes

* fix resnet

* make fix-copies

* correct euler scheduler

* add missing #copied from for preprocess

* revert

* fix

* fix copies

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/cross_attention.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* clean up conversion script

* KDownsample2d,KUpsample2d -> KDownsample2D,KUpsample2D

* more

* Update src/diffusers/models/unet_2d_condition.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* remove prepare_extra_step_kwargs

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix a typo in timestep embedding

* remove num_image_per_prompt

* fix fasttest

* make style + fix-copies

* fix

* fix xformer test

* fix style

* doc string

* make style

* fix-copies

* docstring for time_embedding_norm

* make style

* final finishes

* make fix-copies

* fix tests

---------
Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 3b66cc0f
@@ -145,6 +145,8 @@
title: Image-Variation
- local: api/pipelines/stable_diffusion/upscale
title: Super-Resolution
+- local: api/pipelines/stable_diffusion/latent_upscale
+title: Stable-Diffusion-Latent-Upscaler
- local: api/pipelines/stable_diffusion/pix2pix
title: InstructPix2Pix
title: Stable Diffusion
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Diffusion Latent Upscaler
## StableDiffusionLatentUpscalePipeline
The Stable Diffusion Latent Upscaler model was created by [Katherine Crowson](https://github.com/crowsonkb/k-diffusion) in collaboration with [Stability AI](https://stability.ai/). It operates on the latents produced by a Stable Diffusion pipeline (e.g. [`StableDiffusionPipeline`]) and enhances the output image resolution by a factor of 2.
A notebook that demonstrates the original implementation can be found here:
- [Stable Diffusion Upscaler Demo](https://colab.research.google.com/drive/1o1qYJcFeywzCIdkfKJy7cTpgZTCM2EI4)
Available Checkpoints are:
- *stabilityai/latent-upscaler*: [stabilityai/sd-x2-latent-upscaler](https://huggingface.co/stabilityai/sd-x2-latent-upscaler)
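A typical two-stage workflow generates latents with a base Stable Diffusion checkpoint and then refines them with the upscaler. The snippet below is a sketch of that flow (the base model id and prompt are placeholders, and a CUDA device is assumed):

```python
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.manual_seed(33)

# keep the base output in latent space so it can be fed straight into the upscaler
low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images

upscaled_image = upscaler(
    prompt=prompt,
    image=low_res_latents,
    num_inference_steps=20,
    guidance_scale=0,
    generator=generator,
).images[0]
upscaled_image.save("astronaut_1024.png")
```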
[[autodoc]] StableDiffusionLatentUpscalePipeline
- all
- __call__
- enable_sequential_cpu_offload
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
\ No newline at end of file
@@ -31,6 +31,7 @@ For more details about how Stable Diffusion works and how it differs from the ba
| [StableDiffusionDepth2ImgPipeline](./depth2img) | **Experimental** – *Depth-to-Image Text-Guided Generation * | | Coming soon
| [StableDiffusionImageVariationPipeline](./image_variation) | **Experimental** – *Image Variation Generation * | | [🤗 Stable Diffusion Image Variations](https://huggingface.co/spaces/lambdalabs/stable-diffusion-image-variations)
| [StableDiffusionUpscalePipeline](./upscale) | **Experimental** – *Text-Guided Image Super-Resolution * | | Coming soon
+| [StableDiffusionLatentUpscalePipeline](./latent_upscale) | **Experimental** – *Text-Guided Image Super-Resolution * | | Coming soon
| [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** – *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix)
......
import argparse
import torch
import huggingface_hub
import k_diffusion as K
from diffusers import UNet2DConditionModel
UPSCALER_REPO = "pcuenq/k-upscaler"
def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
rv = {
# norm1
f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
# conv1
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
# norm2
f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
# conv2
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
}
if resnet.conv_shortcut is not None:
rv.update(
{
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
}
)
return rv
def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
rv = {
# norm
f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
# to_q
f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
# to_k
f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
# to_v
f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
# to_out
f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
.squeeze(-1)
.squeeze(-1),
f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
}
return rv
def cross_attn_to_diffusers_checkpoint(
checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
rv = {
# norm2 (ada groupnorm)
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
f"{attention_prefix}.norm_dec.mapper.weight"
],
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
f"{attention_prefix}.norm_dec.mapper.bias"
],
# layernorm on encoder_hidden_state
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
f"{attention_prefix}.norm_enc.weight"
],
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
f"{attention_prefix}.norm_enc.bias"
],
# to_q
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
f"{attention_prefix}.q_proj.weight"
]
.squeeze(-1)
.squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
f"{attention_prefix}.q_proj.bias"
],
# to_k
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
# to_v
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
# to_out
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
f"{attention_prefix}.out_proj.weight"
]
.squeeze(-1)
.squeeze(-1),
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
f"{attention_prefix}.out_proj.bias"
],
}
return rv
def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
block_prefix = f"{block_prefix}.{block_idx}"
diffusers_checkpoint = {}
if not hasattr(block, "attentions"):
n = 1 # resnet only
elif not block.attentions[0].add_self_attention:
n = 2 # resnet -> cross-attention
else:
n = 3 # resnet -> self-attention -> cross-attention)
for resnet_idx, resnet in enumerate(block.resnets):
# diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
resnet_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update(
resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
if hasattr(block, "attentions"):
for attention_idx, attention in enumerate(block.attentions):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
cross_attention_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
diffusers_attention_index=2,
attention_prefix=cross_attention_prefix,
)
)
if attention.add_self_attention is True:
diffusers_checkpoint.update(
self_attn_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
attention_prefix=self_attention_prefix,
)
)
return diffusers_checkpoint
def unet_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# pre-processing
diffusers_checkpoint.update(
{
"conv_in.weight": checkpoint["inner_model.proj_in.weight"],
"conv_in.bias": checkpoint["inner_model.proj_in.bias"],
}
)
# timestep and class embedding
diffusers_checkpoint.update(
{
"time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
"time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
"time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
"time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
"time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
"time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
}
)
# down_blocks
for down_block_idx, down_block in enumerate(model.down_blocks):
diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
# up_blocks
for up_block_idx, up_block in enumerate(model.up_blocks):
diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
# post-processing
diffusers_checkpoint.update(
{
"conv_out.weight": checkpoint["inner_model.proj_out.weight"],
"conv_out.bias": checkpoint["inner_model.proj_out.bias"],
}
)
return diffusers_checkpoint
def unet_model_from_original_config(original_config):
in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
block_out_channels = original_config["channels"]
assert (
len(set(original_config["depths"])) == 1
), "UNet2DConditionModel currently do not support blocks with different number of layers"
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
cross_attention_dim = original_config["cross_cond_dim"]
attn1_types = []
attn2_types = []
for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
if s:
a1 = "self"
a2 = "cross" if c else None
elif c:
a1 = "cross"
a2 = None
else:
a1 = None
a2 = None
attn1_types.append(a1)
attn2_types.append(a2)
unet = UNet2DConditionModel(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
mid_block_type=None,
up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn="gelu",
norm_num_groups=None,
cross_attention_dim=cross_attention_dim,
attention_head_dim=64,
time_cond_proj_dim=class_labels_dim,
resnet_time_scale_shift="scale_shift",
time_embedding_type="fourier",
timestep_post_act="gelu",
conv_in_kernel=1,
conv_out_kernel=1,
)
return unet
def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
orig_weights_path = huggingface_hub.hf_hub_download(
UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
)
print(f"loading original model configuration from {orig_config_path}")
print(f"loading original model checkpoint from {orig_weights_path}")
print("converting to diffusers unet")
orig_config = K.config.load_config(open(orig_config_path))["model"]
model = unet_model_from_original_config(orig_config)
orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
model.load_state_dict(converted_checkpoint, strict=True)
model.save_pretrained(args.dump_path)
print(f"saving converted unet model in {args.dump_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
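Once the script has run, the converted UNet can be reloaded like any other diffusers model; a minimal sketch (the dump path below is hypothetical):

```python
from diffusers import UNet2DConditionModel

# load the weights written by `--dump_path ./k-upscaler-unet` (path is hypothetical)
unet = UNet2DConditionModel.from_pretrained("./k-upscaler-unet")
print(unet.config.time_embedding_type)  # expected: "fourier"
```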
@@ -115,6 +115,7 @@ else:
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
+StableDiffusionLatentUpscalePipeline,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
......
@@ -480,3 +480,38 @@ class AdaLayerNormZero(nn.Module):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaGroupNorm(nn.Module):
"""
GroupNorm layer modified to incorporate timestep embeddings.
"""
def __init__(
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.act = None
if act_fn == "swish":
self.act = lambda x: F.silu(x)
elif act_fn == "mish":
self.act = nn.Mish()
elif act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "gelu":
self.act = nn.GELU()
self.linear = nn.Linear(embedding_dim, out_dim * 2)
def forward(self, x, emb):
if self.act:
emb = self.act(emb)
emb = self.linear(emb)
emb = emb[:, :, None, None]
scale, shift = emb.chunk(2, dim=1)
x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift
return x
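A quick shape sanity check for the new layer (the constructor values are arbitrary, not taken from any checkpoint):

```python
import torch
from diffusers.models.attention import AdaGroupNorm

norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=8, act_fn="swish")
x = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
emb = torch.randn(2, 512)        # timestep embedding
out = norm(x, emb)               # per-channel scale/shift predicted from emb
assert out.shape == x.shape
```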
@@ -17,7 +17,7 @@ import torch
import torch.nn.functional as F
from torch import nn
-from ..utils import logging
+from ..utils import deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -56,6 +56,7 @@ class CrossAttention(nn.Module):
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
+cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
processor: Optional["AttnProcessor"] = None,
@@ -65,6 +66,7 @@ class CrossAttention(nn.Module):
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
+self.cross_attention_norm = cross_attention_norm
self.scale = dim_head**-0.5
@@ -81,6 +83,9 @@ class CrossAttention(nn.Module):
else:
self.group_norm = None
+if cross_attention_norm:
+self.norm_cross = nn.LayerNorm(cross_attention_dim)
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -224,7 +229,19 @@ class CrossAttention(nn.Module):
return attention_probs
-def prepare_attention_mask(self, attention_mask, target_length):
+def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
+if batch_size is None:
+deprecate(
+"batch_size=None",
+"0.0.15",
+message=(
+"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
+" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
+" `prepare_attention_mask` when preparing the attention_mask."
+),
+)
+batch_size = 1
head_size = self.heads
if attention_mask is None:
return attention_mask
@@ -238,18 +255,29 @@ class CrossAttention(nn.Module):
attention_mask = torch.concat([attention_mask, padding], dim=2)
else:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
class CrossAttnProcessor:
-def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+def __call__(
+self,
+attn: CrossAttention,
+hidden_states,
+encoder_hidden_states=None,
+attention_mask=None,
+):
batch_size, sequence_length, _ = hidden_states.shape
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
-encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+if encoder_hidden_states is None:
+encoder_hidden_states = hidden_states
+elif attn.cross_attention_norm:
+encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -305,7 +333,7 @@ class LoRACrossAttnProcessor(nn.Module):
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
@@ -337,7 +365,7 @@ class CrossAttnAddedKVProcessor:
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -379,11 +407,15 @@ class XFormersCrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
-encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+if encoder_hidden_states is None:
+encoder_hidden_states = hidden_states
+elif attn.cross_attention_norm:
+encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -417,7 +449,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
@@ -448,13 +480,17 @@ class SlicedAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
-encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+if encoder_hidden_states is None:
+encoder_hidden_states = hidden_states
+elif attn.cross_attention_norm:
+encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
@@ -500,7 +536,7 @@ class SlicedAttnAddedKVProcessor:
batch_size, sequence_length, _ = hidden_states.shape
-attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
......
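To illustrate the new `cross_attention_norm` flag, a small sketch (head count and feature sizes are arbitrary):

```python
import torch
from diffusers.models.cross_attention import CrossAttention

attn = CrossAttention(query_dim=64, cross_attention_dim=32, heads=4, dim_head=16, cross_attention_norm=True)
hidden_states = torch.randn(2, 16, 64)          # image tokens
encoder_hidden_states = torch.randn(2, 77, 32)  # text tokens, layer-normed before to_k / to_v
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == (2, 16, 64)
```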
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
+from typing import Optional
import numpy as np
import torch
@@ -152,15 +153,32 @@ class PatchEmbed(nn.Module):
class TimestepEmbedding(nn.Module):
-def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
+def __init__(
+self,
+in_channels: int,
+time_embed_dim: int,
+act_fn: str = "silu",
+out_dim: int = None,
+post_act_fn: Optional[str] = None,
+cond_proj_dim=None,
+):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
-self.act = None
+if cond_proj_dim is not None:
+self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+else:
+self.cond_proj = None
if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()
+elif act_fn == "gelu":
+self.act = nn.GELU()
+else:
+raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
if out_dim is not None:
time_embed_dim_out = out_dim
@@ -168,13 +186,29 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
-def forward(self, sample):
+if post_act_fn is None:
+self.post_act = None
+elif post_act_fn == "silu":
+self.post_act = nn.SiLU()
+elif post_act_fn == "mish":
+self.post_act = nn.Mish()
+elif post_act_fn == "gelu":
+self.post_act = nn.GELU()
+else:
+raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+def forward(self, sample, condition=None):
+if condition is not None:
+sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
+if self.post_act is not None:
+sample = self.post_act(sample)
return sample
......
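A sketch of the extended `TimestepEmbedding` (all dimensions are arbitrary): the optional `cond_proj` projects an extra conditioning vector into the timestep features, and `post_act_fn` adds a second activation after `linear_2`.

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding

temb = TimestepEmbedding(in_channels=256, time_embed_dim=1024, act_fn="gelu", post_act_fn="gelu", cond_proj_dim=896)
t_emb = torch.randn(2, 256)   # output of the fourier / positional time projection
cond = torch.randn(2, 896)    # extra conditioning, projected to in_channels and added to t_emb
emb = temb(t_emb, condition=cond)
assert emb.shape == (2, 1024)
```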
from functools import partial
+from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
+from .attention import AdaGroupNorm
class Upsample1D(nn.Module):
"""
@@ -364,7 +367,70 @@ class FirDownsample2D(nn.Module):
return hidden_states
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
def __init__(self, pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class KUpsample2D(nn.Module):
def __init__(self, pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
class ResnetBlock2D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
*,
@@ -378,12 +444,14 @@ class ResnetBlock2D(nn.Module):
pre_norm=True,
eps=1e-6,
non_linearity="swish",
-time_embedding_norm="default",
+time_embedding_norm="default",  # default, scale_shift, ada_group
kernel=None,
output_scale_factor=1.0,
use_in_shortcut=None,
up=False,
down=False,
+conv_shortcut_bias: bool = True,
+conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
@@ -392,40 +460,50 @@ class ResnetBlock2D(nn.Module):
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
-self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
+self.time_embedding_norm = time_embedding_norm
if groups_out is None:
groups_out = groups
-self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+if self.time_embedding_norm == "ada_group":
+self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+else:
+self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
-time_emb_proj_out_channels = out_channels
+self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
-time_emb_proj_out_channels = out_channels * 2
+self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+elif self.time_embedding_norm == "ada_group":
+self.time_emb_proj = None
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
-self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
-self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+if self.time_embedding_norm == "ada_group":
+self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+else:
+self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
-self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+conv_2d_out_channels = conv_2d_out_channels or out_channels
+self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
-self.nonlinearity = Mish()
+self.nonlinearity = nn.Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
+elif non_linearity == "gelu":
+self.nonlinearity = nn.GELU()
self.upsample = self.downsample = None
if self.up:
@@ -445,16 +523,22 @@ class ResnetBlock2D(nn.Module):
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
-self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
-self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+self.conv_shortcut = torch.nn.Conv2d(
+in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
+)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
-hidden_states = self.norm1(hidden_states)
+if self.time_embedding_norm == "ada_group":
+hidden_states = self.norm1(hidden_states, temb)
+else:
+hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
@@ -470,13 +554,16 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv1(hidden_states)
-if temb is not None:
+if self.time_emb_proj is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
-hidden_states = self.norm2(hidden_states)
+if self.time_embedding_norm == "ada_group":
+hidden_states = self.norm2(hidden_states, temb)
+else:
+hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
......
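A sketch exercising the new resnet options and the k-upscaler resampling layers shown above (channel counts and sizes are arbitrary):

```python
import torch
from diffusers.models.resnet import KDownsample2D, KUpsample2D, ResnetBlock2D

# AdaGroupNorm-conditioned resnet block, roughly in the k-upscaler style
block = ResnetBlock2D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    groups=8,
    non_linearity="gelu",
    time_embedding_norm="ada_group",   # norm1 / norm2 become AdaGroupNorm layers
    conv_shortcut_bias=False,
)
x = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 512)
assert block(x, temb).shape == x.shape

# the fixed blur-kernel resamplers halve / double the spatial resolution
assert KDownsample2D()(x).shape == (2, 64, 16, 16)
assert KUpsample2D()(x).shape == (2, 64, 64, 64)
```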
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
-from .embeddings import TimestepEmbedding, Timesteps
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
@@ -70,9 +70,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
+The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
+mid block layer if `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use.
+only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+Whether to include self-attention in the basic transformer blocks, see
+[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
@@ -80,6 +84,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
@@ -90,6 +95,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
+time_embedding_type (`str`, *optional*, default to `positional`):
+The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+timestep_post_act (`str`, *optional*, default to `None`):
+The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+time_cond_proj_dim (`int`, *optional*, default to `None`):
+The dimension of `cond_proj` layer in timestep embedding.
+conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
"""
_supports_gradient_checkpointing = True
@@ -109,7 +122,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
"CrossAttnDownBlock2D",
"DownBlock2D",
),
-mid_block_type: str = "UNetMidBlock2DCrossAttn",
+mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
@@ -117,7 +130,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
-norm_num_groups: int = 32,
+norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -127,20 +140,48 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
+time_embedding_type: str = "positional",  # fourier, positional
+timestep_post_act: Optional[str] = None,
+time_cond_proj_dim: Optional[int] = None,
+conv_in_kernel: int = 3,
+conv_out_kernel: int = 3,
):
super().__init__()
self.sample_size = sample_size
-time_embed_dim = block_out_channels[0] * 4
# input
-self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+conv_in_padding = (conv_in_kernel - 1) // 2
+self.conv_in = nn.Conv2d(
+in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+)
# time
-self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
-timestep_input_dim = block_out_channels[0]
+if time_embedding_type == "fourier":
+time_embed_dim = block_out_channels[0] * 2
+if time_embed_dim % 2 != 0:
+raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+self.time_proj = GaussianFourierProjection(
+time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+)
+timestep_input_dim = time_embed_dim
+elif time_embedding_type == "positional":
+time_embed_dim = block_out_channels[0] * 4
+self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+timestep_input_dim = block_out_channels[0]
+else:
+raise ValueError(
+f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+)
-self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+self.time_embedding = TimestepEmbedding(
+timestep_input_dim,
+time_embed_dim,
+act_fn=act_fn,
+post_act_fn=timestep_post_act,
+cond_proj_dim=time_cond_proj_dim,
+)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
@@ -153,7 +194,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
-self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
@@ -218,6 +258,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
+elif mid_block_type is None:
+self.mid_block = None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@@ -228,6 +270,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
...@@ -266,9 +309,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -266,9 +309,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
prev_output_channel = output_channel prev_output_channel = output_channel
# out # out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) if norm_num_groups is not None:
self.conv_act = nn.SiLU() self.conv_norm_out = nn.GroupNorm(
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@property @property
def attn_processors(self) -> Dict[str, AttnProcessor]: def attn_processors(self) -> Dict[str, AttnProcessor]:
...@@ -399,6 +452,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -399,6 +452,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True, return_dict: bool = True,
...@@ -466,7 +520,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -466,7 +520,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# but time_embedding might actually be running in fp16. so we need to cast here. # but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this. # there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype) t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None: if self.class_embedding is not None:
if class_labels is None: if class_labels is None:
...@@ -498,13 +553,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -498,13 +553,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_res_samples += res_samples down_block_res_samples += res_samples
# 4. mid # 4. mid
sample = self.mid_block( if self.mid_block is not None:
sample, sample = self.mid_block(
emb, sample,
encoder_hidden_states=encoder_hidden_states, emb,
attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask,
) cross_attention_kwargs=cross_attention_kwargs,
)
# 5. up # 5. up
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
...@@ -533,8 +589,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -533,8 +589,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
) )
# 6. post-process # 6. post-process
sample = self.conv_norm_out(sample) if self.conv_norm_out:
sample = self.conv_act(sample) sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample) sample = self.conv_out(sample)
if not return_dict: if not return_dict:
......
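As a rough illustration of the new `UNet2DConditionModel` configuration surface (not code from this PR), the sketch below wires up a tiny UNet with Fourier timestep features, a skipped mid block, and custom conv kernel sizes. The block types and sizes are toy values chosen for a quick run, not the k-upscaler configuration; `norm_num_groups=None` (skipping the final norm/act) is also supported by the changes above but not exercised here.

```python
import torch
from diffusers import UNet2DConditionModel

# Toy configuration exercising the new arguments added in this commit.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    mid_block_type=None,            # skip the mid block entirely
    norm_num_groups=8,
    cross_attention_dim=32,
    attention_head_dim=8,
    time_embedding_type="fourier",  # GaussianFourierProjection instead of sinusoidal Timesteps
    conv_in_kernel=1,               # padding is derived as (kernel - 1) // 2
    conv_out_kernel=1,
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=torch.tensor([10.0]), encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```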
@@ -52,6 +52,7 @@ else:
         StableDiffusionInpaintPipeline,
         StableDiffusionInpaintPipelineLegacy,
         StableDiffusionInstructPix2PixPipeline,
+        StableDiffusionLatentUpscalePipeline,
         StableDiffusionPipeline,
         StableDiffusionUpscalePipeline,
     )
...
@@ -632,15 +632,24 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

-        # 8. Post-processing
-        image = self.decode_latents(latents)
-        # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+        elif output_type == "pil":
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            # 10. Convert to PIL
+            image = self.numpy_to_pil(image)
+        else:
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

         if not return_dict:
             return (image, has_nsfw_concept)
...
@@ -44,6 +44,7 @@ if is_transformers_available() and is_torch_available():
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
     from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
     from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
+    from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
     from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
     from .safety_checker import StableDiffusionSafetyChecker
...
@@ -629,15 +629,24 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

-        # 8. Post-processing
-        image = self.decode_latents(latents)
-        # 9. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        # 10. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+        elif output_type == "pil":
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            # 10. Convert to PIL
+            image = self.numpy_to_pil(image)
+        else:
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

         if not return_dict:
             return (image, has_nsfw_concept)
...
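The `output_type="latent"` branch above exists mainly so the base pipeline can hand its latents directly to the new latent upscaler. A hedged usage sketch follows; the Hub checkpoint ids and the exact upscaler call signature (`image=`, `guidance_scale=0`) are assumptions based on the accompanying release rather than taken from this diff, and a CUDA GPU is assumed for fp16.

```python
import torch
from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline

# Base text-to-image pipeline (checkpoint id is an assumption; any SD checkpoint should work).
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.to("cuda")

# Latent upscaler added by this PR (hub id assumed).
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)
upscaler.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.manual_seed(33)

# With output_type="latent" the pipeline returns raw latents and skips VAE decode + safety checker.
low_res_latents = pipe(prompt, generator=generator, output_type="latent").images

# The upscaler consumes those latents and returns a 2x larger decoded image.
image = upscaler(
    prompt=prompt,
    image=low_res_latents,
    num_inference_steps=20,
    guidance_scale=0,
    generator=generator,
).images[0]
image.save("astronaut_2x.png")
```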
@@ -9,7 +9,7 @@ from ...models import ModelMixin
 from ...models.attention import CrossAttention
 from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
 from ...models.dual_transformer_2d import DualTransformer2DModel
-from ...models.embeddings import TimestepEmbedding, Timesteps
+from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
 from ...utils import logging
@@ -151,9 +151,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
             The tuple of downsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
-            The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`.
+            The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`; the
+            mid block layer is skipped if `None`.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
             The tuple of upsample blocks to use.
+        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+            Whether to include self-attention in the basic transformer blocks, see
+            [`~models.attention.BasicTransformerBlock`].
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
@@ -161,6 +165,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, the normalization and activation layers are skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
@@ -171,6 +176,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
+            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
+            The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of the `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
...@@ -190,7 +203,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -190,7 +203,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"DownBlockFlat", "DownBlockFlat",
), ),
mid_block_type: str = "UNetMidBlockFlatCrossAttn", mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = ( up_block_types: Tuple[str] = (
"UpBlockFlat", "UpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
...@@ -203,7 +216,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -203,7 +216,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
act_fn: str = "silu", act_fn: str = "silu",
norm_num_groups: int = 32, norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
cross_attention_dim: int = 1280, cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8, attention_head_dim: Union[int, Tuple[int]] = 8,
...@@ -213,20 +226,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -213,20 +226,48 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
time_embedding_type: str = "positional", # fourier, positional
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
): ):
super().__init__() super().__init__()
self.sample_size = sample_size self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input # input
self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = LinearMultiDim(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time # time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) if time_embedding_type == "fourier":
timestep_input_dim = block_out_channels[0] time_embed_dim = block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
+        else:
+            raise ValueError(
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+            )
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+            post_act_fn=timestep_post_act,
+            cond_proj_dim=time_cond_proj_dim,
+        )

         # class embedding
         if class_embed_type is None and num_class_embeds is not None:
@@ -239,7 +280,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             self.class_embedding = None

         self.down_blocks = nn.ModuleList([])
-        self.mid_block = None
         self.up_blocks = nn.ModuleList([])

         if isinstance(only_cross_attention, bool):
@@ -304,6 +344,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
             )
+        elif mid_block_type is None:
+            self.mid_block = None
         else:
             raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@@ -314,6 +356,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
         only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -352,9 +395,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             prev_output_channel = output_channel

         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
-        self.conv_act = nn.SiLU()
-        self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+        if norm_num_groups is not None:
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            )
+            self.conv_act = nn.SiLU()
+        else:
+            self.conv_norm_out = None
+            self.conv_act = None
+        conv_out_padding = (conv_out_kernel - 1) // 2
+        self.conv_out = LinearMultiDim(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )

     @property
     def attn_processors(self) -> Dict[str, AttnProcessor]:
@@ -485,6 +538,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
@@ -552,7 +606,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
-        emb = self.time_embedding(t_emb)
+        emb = self.time_embedding(t_emb, timestep_cond)

         if self.class_embedding is not None:
             if class_labels is None:
@@ -584,13 +639,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             down_block_res_samples += res_samples

         # 4. mid
-        sample = self.mid_block(
-            sample,
-            emb,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            cross_attention_kwargs=cross_attention_kwargs,
-        )
+        if self.mid_block is not None:
+            sample = self.mid_block(
+                sample,
+                emb,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                cross_attention_kwargs=cross_attention_kwargs,
+            )

         # 5. up
         for i, upsample_block in enumerate(self.up_blocks):
@@ -619,8 +675,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                     hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                 )

         # 6. post-process
-        sample = self.conv_norm_out(sample)
-        sample = self.conv_act(sample)
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
         sample = self.conv_out(sample)

         if not return_dict:
...
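Both UNet variants above now support `time_embedding_type="fourier"` via `GaussianFourierProjection`. As a conceptual aid only (a simplified sketch of random Fourier features, not the diffusers implementation or its exact argument handling), the idea looks like this:

```python
import math

import torch


class FourierTimestepFeatures(torch.nn.Module):
    """Map a scalar timestep t to [sin(2*pi*w_i*t), cos(2*pi*w_i*t)] with fixed random frequencies w."""

    def __init__(self, embedding_size: int = 256, scale: float = 1.0):
        super().__init__()
        # Frequencies are sampled once and frozen (registered as a buffer, not a trainable parameter).
        self.register_buffer("weight", torch.randn(embedding_size) * scale)

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        # timesteps: shape (batch,); output: shape (batch, 2 * embedding_size)
        x = timesteps.float()[:, None] * self.weight[None, :] * 2 * math.pi
        return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)


proj = FourierTimestepFeatures(embedding_size=128)
print(proj(torch.tensor([0.5, 0.9])).shape)  # torch.Size([2, 256])
```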
@@ -65,11 +65,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-        prediction_type (`str`, default `epsilon`, optional):
+        prediction_type (`str`, default `"epsilon"`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
             https://imagen.research.google/video/paper.pdf)
+        interpolation_type (`str`, default `"linear"`, optional):
+            interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
+            [`"linear"`, `"log_linear"`].
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -84,6 +86,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
+        interpolation_type: str = "linear",
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -130,7 +133,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             timestep = timestep.to(self.timesteps.device)
         step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]
         sample = sample / ((sigma**2 + 1) ** 0.5)
         self.is_scale_input_called = True
         return sample
@@ -148,7 +153,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+        if self.config.interpolation_type == "linear":
+            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+        elif self.config.interpolation_type == "log_linear":
+            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp()
+        else:
+            raise ValueError(
+                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+                " 'linear' or 'log_linear'"
+            )
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         if str(device).startswith("mps"):
@@ -230,7 +245,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        if self.config.prediction_type == "epsilon":
+        if self.config.prediction_type == "original_sample":
+            pred_original_sample = model_output
+        elif self.config.prediction_type == "epsilon":
             pred_original_sample = sample - sigma_hat * model_output
         elif self.config.prediction_type == "v_prediction":
             # * c_out + input * c_skip
...
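To make the new `interpolation_type` concrete, here is a standalone sketch of the two sigma schedules computed in `set_timesteps` above, run on made-up training sigmas (the cumulative alphas below are illustrative, not the scheduler's actual beta schedule):

```python
import numpy as np
import torch

num_train_timesteps = 1000
num_inference_steps = 10

# Made-up cumulative alphas; the derived sigmas increase with the training step index.
alphas_cumprod = np.linspace(0.99, 0.01, num_train_timesteps)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

# Inference timesteps run from high noise to low noise.
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

# "linear": interpolate the training sigmas at (fractional) timestep positions.
linear_sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

# "log_linear": space sigmas evenly in log space from sigma_max down to sigma_min
# (the option added in this PR for the latent upscaler).
log_linear_sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp()

print(linear_sigmas[:3])
print(log_linear_sigmas[:3])
```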
@@ -169,6 +169,21 @@ class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])

+class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
 class StableDiffusionPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...
@@ -252,7 +252,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
             def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
                 batch_size, sequence_length, _ = hidden_states.shape
-                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

                 query = attn.to_q(hidden_states)
...
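The test hunk above reflects the updated `prepare_attention_mask(attention_mask, sequence_length, batch_size)` signature. Below is a rough sketch of a custom processor built around it, mirroring the structure of the default processor; it is illustrative only and not code from this PR.

```python
import torch


class LoggingAttnProcessor:
    """Minimal custom cross-attention processor: default attention math plus a shape printout."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        # batch_size is now passed explicitly when preparing the mask.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        print("attention_probs:", tuple(attention_probs.shape))

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # to_out holds the output projection and, depending on config, a dropout layer.
        for module in attn.to_out:
            hidden_states = module(hidden_states)
        return hidden_states
```

It could be registered on a model with something like `unet.set_attn_processor(LoggingAttnProcessor())`, assuming that setter is available in this version.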