Unverified Commit 40398152 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

open muse (#5437)



amused

rename

Update docs/source/en/api/pipelines/amused.md
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

AdaLayerNormContinuous default values

custom micro conditioning

micro conditioning docs

put lookup from codebook in constructor

fix conversion script

remove manual fused flash attn kernel

add training script

temp remove training script

add dummy gradient checkpointing func

clarify temperatures is an instance variable by setting it

remove additional SkipFF block args

hardcode norm args

rename tests folder

fix paths and samples

fix tests

add training script

training readme

lora saving and loading

non-lora saving/loading

some readme fixes

guards

Update docs/source/en/api/pipelines/amused.md
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

Update examples/amused/README.md
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

Update examples/amused/train_amused.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

vae upcasting

add fp16 integration tests

use tuple for micro cond

copyrights

remove casts

delegate to torch.nn.LayerNorm

move temperature to pipeline call

upsampling/downsampling changes
parent 5b186b71
...@@ -244,6 +244,8 @@ ...@@ -244,6 +244,8 @@
- sections: - sections:
- local: api/pipelines/overview - local: api/pipelines/overview
title: Overview title: Overview
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff - local: api/pipelines/animatediff
title: AnimateDiff title: AnimateDiff
- local: api/pipelines/attend_and_excite - local: api/pipelines/attend_and_excite
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# aMUSEd
Amused is a lightweight text to image model based off of the [muse](https://arxiv.org/pdf/2301.00704.pdf) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes.
| Model | Params |
|-------|--------|
| [amused-256](https://huggingface.co/huggingface/amused-256) | 603M |
| [amused-512](https://huggingface.co/huggingface/amused-512) | 608M |
## AmusedPipeline
[[autodoc]] AmusedPipeline
- __call__
- all
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
\ No newline at end of file
## Amused training
Amused can be finetuned on simple datasets relatively cheaply and quickly. Using 8bit optimizers, lora, and gradient accumulation, amused can be finetuned with as little as 5.5 GB. Here are a set of examples for finetuning amused on some relatively simple datasets. These training recipies are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease learning rates.
All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8 bit adam + lora as its about the same memory use as just using lora (bitsandbytes uses full precision optimizer states for weights below a minimum size).
### Finetuning the 256 checkpoint
These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset.
Example results:
![noun1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun1.png) ![noun2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun2.png) ![noun3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun3.png)
#### Full finetuning
Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 19.7 GB |
| 4 | 2 | 8 | 18.3 GB |
| 1 | 8 | 8 | 17.9 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 1e-4 \
--pretrained_model_name_or_path huggingface/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
--resolution 256 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
'a pixel art character with square red glasses' \
'a pixel art character' \
'square red glasses on a pixel art character' \
'square red glasses on a pixel art character with a baseball-shaped head' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + 8 bit adam
Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8 bit adam, it will diverge eventually. If you want to train for longer, you will have to up the batch size and lower the learning rate.
Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 16 | 1 | 16 | 20.1 GB |
| 8 | 2 | 16 | 15.6 GB |
| 1 | 16 | 16 | 10.7 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 2e-5 \
--use_8bit_adam \
--pretrained_model_name_or_path huggingface/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
--resolution 256 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
'a pixel art character with square red glasses' \
'a pixel art character' \
'square red glasses on a pixel art character' \
'square red glasses on a pixel art character with a baseball-shaped head' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + lora
Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 16 | 1 | 16 | 14.1 GB |
| 8 | 2 | 16 | 10.1 GB |
| 1 | 16 | 16 | 6.5 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 8e-4 \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
--resolution 256 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
'a pixel art character with square red glasses' \
'a pixel art character' \
'square red glasses on a pixel art character' \
'square red glasses on a pixel art character with a baseball-shaped head' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
### Finetuning the 512 checkpoint
These examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset.
Example results:
![minecraft1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft1.png) ![minecraft2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft2.png) ![minecraft3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft3.png)
#### Full finetuning
Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 24.2 GB |
| 4 | 2 | 8 | 19.7 GB |
| 1 | 8 | 8 | 16.99 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 8e-5 \
--pretrained_model_name_or_path huggingface/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
--prompt_key text \
--resolution 512 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'minecraft Avatar' \
'minecraft character' \
'minecraft' \
'minecraft president' \
'minecraft pig' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + 8 bit adam
Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 21.2 GB |
| 4 | 2 | 8 | 13.3 GB |
| 1 | 8 | 8 | 9.9 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 5e-6 \
--pretrained_model_name_or_path huggingface/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
--prompt_key text \
--resolution 512 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'minecraft Avatar' \
'minecraft character' \
'minecraft' \
'minecraft president' \
'minecraft pig' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
#### Full finetuning + lora
Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|------------------|-------------|
| 8 | 1 | 8 | 12.7 GB |
| 4 | 2 | 8 | 9.0 GB |
| 1 | 8 | 8 | 5.6 GB |
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 1e-4 \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
--prompt_key text \
--resolution 512 \
--mixed_precision fp16 \
--lr_scheduler constant \
--validation_prompts \
'minecraft Avatar' \
'minecraft character' \
'minecraft' \
'minecraft president' \
'minecraft pig' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 250 \
--gradient_checkpointing
```
### Styledrop
[Styledrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage to generate human picked additional training samples. The additional training samples can be used to augment the initial images. Our examples exclude the optional additional image selection stage and instead we just finetune on a single image.
This is our example style image:
![example](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png)
Download it to your local directory with
```sh
wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png
```
#### 256
Example results:
![glowing_256_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_1.png) ![glowing_256_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_2.png) ![glowing_256_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_3.png)
Learning rate: 4e-4, Gives decent results in 1500-2000 steps
Memory used: 6.5 GB
```sh
accelerate launch train_amused.py \
--output_dir <output path> \
--mixed_precision fp16 \
--report_to wandb \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-256 \
--train_batch_size 1 \
--lr_scheduler constant \
--learning_rate 4e-4 \
--validation_prompts \
'A chihuahua walking on the street in [V] style' \
'A banana on the table in [V] style' \
'A church on the street in [V] style' \
'A tabby cat walking in the forest in [V] style' \
--instance_data_image 'A mushroom in [V] style.png' \
--max_train_steps 10000 \
--checkpointing_steps 500 \
--validation_steps 100 \
--resolution 256
```
#### 512
Example results:
![glowing_512_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_1.png) ![glowing_512_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_2.png) ![glowing_512_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_3.png)
Learning rate: 1e-3, Lora alpha 1, Gives decent results in 1500-2000 steps
Memory used: 5.6 GB
```
accelerate launch train_amused.py \
--output_dir <output path> \
--mixed_precision fp16 \
--report_to wandb \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-512 \
--train_batch_size 1 \
--lr_scheduler constant \
--learning_rate 1e-3 \
--validation_prompts \
'A chihuahua walking on the street in [V] style' \
'A banana on the table in [V] style' \
'A church on the street in [V] style' \
'A tabby cat walking in the forest in [V] style' \
--instance_data_image 'A mushroom in [V] style.png' \
--max_train_steps 100000 \
--checkpointing_steps 500 \
--validation_steps 100 \
--resolution 512 \
--lora_alpha 1
```
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
...@@ -95,6 +95,7 @@ else: ...@@ -95,6 +95,7 @@ else:
"UNet3DConditionModel", "UNet3DConditionModel",
"UNetMotionModel", "UNetMotionModel",
"UNetSpatioTemporalConditionModel", "UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel", "VQModel",
] ]
) )
...@@ -131,6 +132,7 @@ else: ...@@ -131,6 +132,7 @@ else:
) )
_import_structure["schedulers"].extend( _import_structure["schedulers"].extend(
[ [
"AmusedScheduler",
"CMStochasticIterativeScheduler", "CMStochasticIterativeScheduler",
"DDIMInverseScheduler", "DDIMInverseScheduler",
"DDIMParallelScheduler", "DDIMParallelScheduler",
...@@ -202,6 +204,9 @@ else: ...@@ -202,6 +204,9 @@ else:
[ [
"AltDiffusionImg2ImgPipeline", "AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline", "AltDiffusionPipeline",
"AmusedImg2ImgPipeline",
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline", "AnimateDiffPipeline",
"AudioLDM2Pipeline", "AudioLDM2Pipeline",
"AudioLDM2ProjectionModel", "AudioLDM2ProjectionModel",
...@@ -472,6 +477,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -472,6 +477,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet3DConditionModel, UNet3DConditionModel,
UNetMotionModel, UNetMotionModel,
UNetSpatioTemporalConditionModel, UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel, VQModel,
) )
from .optimization import ( from .optimization import (
...@@ -506,6 +512,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -506,6 +512,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ScoreSdeVePipeline, ScoreSdeVePipeline,
) )
from .schedulers import ( from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler, CMStochasticIterativeScheduler,
DDIMInverseScheduler, DDIMInverseScheduler,
DDIMParallelScheduler, DDIMParallelScheduler,
...@@ -560,6 +567,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -560,6 +567,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipelines import ( from .pipelines import (
AltDiffusionImg2ImgPipeline, AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline, AltDiffusionPipeline,
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline, AnimateDiffPipeline,
AudioLDM2Pipeline, AudioLDM2Pipeline,
AudioLDM2ProjectionModel, AudioLDM2ProjectionModel,
......
...@@ -59,6 +59,7 @@ logger = logging.get_logger(__name__) ...@@ -59,6 +59,7 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder" TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet" UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
...@@ -74,6 +75,7 @@ class LoraLoaderMixin: ...@@ -74,6 +75,7 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME unet_name = UNET_NAME
transformer_name = TRANSFORMER_NAME
num_fused_loras = 0 num_fused_loras = 0
def load_lora_weights( def load_lora_weights(
...@@ -661,6 +663,89 @@ class LoraLoaderMixin: ...@@ -661,6 +663,89 @@ class LoraLoaderMixin:
_pipeline.enable_sequential_cpu_offload() _pipeline.enable_sequential_cpu_offload()
# Unsafe code /> # Unsafe code />
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
network_alphas = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
if len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@property @property
def lora_scale(self) -> float: def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline. # property function that returns the lora scale which can be set at run time by the pipeline.
...@@ -786,6 +871,7 @@ class LoraLoaderMixin: ...@@ -786,6 +871,7 @@ class LoraLoaderMixin:
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
...@@ -820,8 +906,10 @@ class LoraLoaderMixin: ...@@ -820,8 +906,10 @@ class LoraLoaderMixin:
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict return layers_state_dict
if not (unet_lora_layers or text_encoder_lora_layers): if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.") raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`."
)
if unet_lora_layers: if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet")) state_dict.update(pack_weights(unet_lora_layers, "unet"))
...@@ -829,6 +917,9 @@ class LoraLoaderMixin: ...@@ -829,6 +917,9 @@ class LoraLoaderMixin:
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
if transformer_lora_layers:
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
state_dict=state_dict, state_dict=state_dict,
......
...@@ -47,6 +47,7 @@ if is_torch_available(): ...@@ -47,6 +47,7 @@ if is_torch_available():
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"] _import_structure["vq_model"] = ["VQModel"]
if is_flax_available(): if is_flax_available():
...@@ -81,6 +82,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -81,6 +82,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_kandinsky3 import Kandinsky3UNet from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .uvit_2d import UVit2DModel
from .vq_model import VQModel from .vq_model import VQModel
if is_flax_available(): if is_flax_available():
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from ..utils import USE_PEFT_BACKEND from ..utils import USE_PEFT_BACKEND
...@@ -22,7 +23,7 @@ from .activations import GEGLU, GELU, ApproximateGELU ...@@ -22,7 +23,7 @@ from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
def _chunked_feed_forward( def _chunked_feed_forward(
...@@ -148,6 +149,11 @@ class BasicTransformerBlock(nn.Module): ...@@ -148,6 +149,11 @@ class BasicTransformerBlock(nn.Module):
attention_type: str = "default", attention_type: str = "default",
positional_embeddings: Optional[str] = None, positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None, num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
): ):
super().__init__() super().__init__()
self.only_cross_attention = only_cross_attention self.only_cross_attention = only_cross_attention
...@@ -156,6 +162,7 @@ class BasicTransformerBlock(nn.Module): ...@@ -156,6 +162,7 @@ class BasicTransformerBlock(nn.Module):
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single" self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm" self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError( raise ValueError(
...@@ -179,6 +186,15 @@ class BasicTransformerBlock(nn.Module): ...@@ -179,6 +186,15 @@ class BasicTransformerBlock(nn.Module):
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero: elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else: else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
...@@ -190,6 +206,7 @@ class BasicTransformerBlock(nn.Module): ...@@ -190,6 +206,7 @@ class BasicTransformerBlock(nn.Module):
bias=attention_bias, bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None, cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) )
# 2. Cross-Attn # 2. Cross-Attn
...@@ -197,11 +214,20 @@ class BasicTransformerBlock(nn.Module): ...@@ -197,11 +214,20 @@ class BasicTransformerBlock(nn.Module):
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block. # the second cross attention block.
self.norm2 = ( if self.use_ada_layer_norm:
AdaLayerNorm(dim, num_embeds_ada_norm) self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm elif self.use_ada_layer_norm_continuous:
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.norm2 = AdaLayerNormContinuous(
) dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention( self.attn2 = Attention(
query_dim=dim, query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None, cross_attention_dim=cross_attention_dim if not double_self_attention else None,
...@@ -210,20 +236,32 @@ class BasicTransformerBlock(nn.Module): ...@@ -210,20 +236,32 @@ class BasicTransformerBlock(nn.Module):
dropout=dropout, dropout=dropout,
bias=attention_bias, bias=attention_bias,
upcast_attention=upcast_attention, upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none ) # is self-attn if encoder_hidden_states is none
else: else:
self.norm2 = None self.norm2 = None
self.attn2 = None self.attn2 = None
# 3. Feed-forward # 3. Feed-forward
if not self.use_ada_layer_norm_single: if self.use_ada_layer_norm_continuous:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward( self.ff = FeedForward(
dim, dim,
dropout=dropout, dropout=dropout,
activation_fn=activation_fn, activation_fn=activation_fn,
final_dropout=final_dropout, final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
) )
# 4. Fuser # 4. Fuser
...@@ -252,6 +290,7 @@ class BasicTransformerBlock(nn.Module): ...@@ -252,6 +290,7 @@ class BasicTransformerBlock(nn.Module):
timestep: Optional[torch.LongTensor] = None, timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None, cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None, class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks. # Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention # 0. Self-Attention
...@@ -265,6 +304,8 @@ class BasicTransformerBlock(nn.Module): ...@@ -265,6 +304,8 @@ class BasicTransformerBlock(nn.Module):
) )
elif self.use_layer_norm: elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.use_ada_layer_norm_single: elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
...@@ -314,6 +355,8 @@ class BasicTransformerBlock(nn.Module): ...@@ -314,6 +355,8 @@ class BasicTransformerBlock(nn.Module):
# For PixArt norm2 isn't applied here: # For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states norm_hidden_states = hidden_states
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else: else:
raise ValueError("Incorrect norm") raise ValueError("Incorrect norm")
...@@ -329,7 +372,9 @@ class BasicTransformerBlock(nn.Module): ...@@ -329,7 +372,9 @@ class BasicTransformerBlock(nn.Module):
hidden_states = attn_output + hidden_states hidden_states = attn_output + hidden_states
# 4. Feed-forward # 4. Feed-forward
if not self.use_ada_layer_norm_single: if self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states) norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero: if self.use_ada_layer_norm_zero:
...@@ -490,6 +535,78 @@ class TemporalBasicTransformerBlock(nn.Module): ...@@ -490,6 +535,78 @@ class TemporalBasicTransformerBlock(nn.Module):
return hidden_states return hidden_states
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
)
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
class FeedForward(nn.Module): class FeedForward(nn.Module):
r""" r"""
A feed-forward layer. A feed-forward layer.
...@@ -512,10 +629,12 @@ class FeedForward(nn.Module): ...@@ -512,10 +629,12 @@ class FeedForward(nn.Module):
dropout: float = 0.0, dropout: float = 0.0,
activation_fn: str = "geglu", activation_fn: str = "geglu",
final_dropout: bool = False, final_dropout: bool = False,
inner_dim=None,
bias: bool = True, bias: bool = True,
): ):
super().__init__() super().__init__()
inner_dim = int(dim * mult) if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
......
...@@ -77,6 +77,7 @@ class Encoder(nn.Module): ...@@ -77,6 +77,7 @@ class Encoder(nn.Module):
norm_num_groups: int = 32, norm_num_groups: int = 32,
act_fn: str = "silu", act_fn: str = "silu",
double_z: bool = True, double_z: bool = True,
mid_block_add_attention=True,
): ):
super().__init__() super().__init__()
self.layers_per_block = layers_per_block self.layers_per_block = layers_per_block
...@@ -124,6 +125,7 @@ class Encoder(nn.Module): ...@@ -124,6 +125,7 @@ class Encoder(nn.Module):
attention_head_dim=block_out_channels[-1], attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
temb_channels=None, temb_channels=None,
add_attention=mid_block_add_attention,
) )
# out # out
...@@ -213,6 +215,7 @@ class Decoder(nn.Module): ...@@ -213,6 +215,7 @@ class Decoder(nn.Module):
norm_num_groups: int = 32, norm_num_groups: int = 32,
act_fn: str = "silu", act_fn: str = "silu",
norm_type: str = "group", # group, spatial norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
): ):
super().__init__() super().__init__()
self.layers_per_block = layers_per_block self.layers_per_block = layers_per_block
...@@ -240,6 +243,7 @@ class Decoder(nn.Module): ...@@ -240,6 +243,7 @@ class Decoder(nn.Module):
attention_head_dim=block_out_channels[-1], attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups, resnet_groups=norm_num_groups,
temb_channels=temb_channels, temb_channels=temb_channels,
add_attention=mid_block_add_attention,
) )
# up # up
......
...@@ -20,6 +20,7 @@ import torch.nn.functional as F ...@@ -20,6 +20,7 @@ import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv from .lora import LoRACompatibleConv
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native from .upsampling import upfirdn2d_native
...@@ -89,6 +90,11 @@ class Downsample2D(nn.Module): ...@@ -89,6 +90,11 @@ class Downsample2D(nn.Module):
out_channels: Optional[int] = None, out_channels: Optional[int] = None,
padding: int = 1, padding: int = 1,
name: str = "conv", name: str = "conv",
kernel_size=3,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
...@@ -99,8 +105,19 @@ class Downsample2D(nn.Module): ...@@ -99,8 +105,19 @@ class Downsample2D(nn.Module):
self.name = name self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv: if use_conv:
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) conv = conv_cls(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else: else:
assert self.channels == self.out_channels assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride) conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
...@@ -117,6 +134,9 @@ class Downsample2D(nn.Module): ...@@ -117,6 +134,9 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if self.use_conv and self.padding == 0: if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
......
...@@ -197,11 +197,12 @@ class TimestepEmbedding(nn.Module): ...@@ -197,11 +197,12 @@ class TimestepEmbedding(nn.Module):
out_dim: int = None, out_dim: int = None,
post_act_fn: Optional[str] = None, post_act_fn: Optional[str] = None,
cond_proj_dim=None, cond_proj_dim=None,
sample_proj_bias=True,
): ):
super().__init__() super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_1 = linear_cls(in_channels, time_embed_dim) self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None: if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
...@@ -214,7 +215,7 @@ class TimestepEmbedding(nn.Module): ...@@ -214,7 +215,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim time_embed_dim_out = out_dim
else: else:
time_embed_dim_out = time_embed_dim time_embed_dim_out = time_embed_dim
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out) self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None: if post_act_fn is None:
self.post_act = None self.post_act = None
......
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numbers
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..utils import is_torch_version
from .activations import get_activation from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
...@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module): ...@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
x = F.group_norm(x, self.num_groups, eps=self.eps) x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift x = x * (1 + scale) + shift
return x return x
class AdaLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
class RMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * nx) + self.beta + x
...@@ -20,6 +20,7 @@ import torch.nn.functional as F ...@@ -20,6 +20,7 @@ import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleConv from .lora import LoRACompatibleConv
from .normalization import RMSNorm
class Upsample1D(nn.Module): class Upsample1D(nn.Module):
...@@ -95,6 +96,13 @@ class Upsample2D(nn.Module): ...@@ -95,6 +96,13 @@ class Upsample2D(nn.Module):
use_conv_transpose: bool = False, use_conv_transpose: bool = False,
out_channels: Optional[int] = None, out_channels: Optional[int] = None,
name: str = "conv", name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
interpolate=True,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
...@@ -102,13 +110,29 @@ class Upsample2D(nn.Module): ...@@ -102,13 +110,29 @@ class Upsample2D(nn.Module):
self.use_conv = use_conv self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose self.use_conv_transpose = use_conv_transpose
self.name = name self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
conv = None conv = None
if use_conv_transpose: if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) if kernel_size is None:
kernel_size = 4
conv = nn.ConvTranspose2d(
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
)
elif use_conv: elif use_conv:
conv = conv_cls(self.channels, self.out_channels, 3, padding=1) if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv": if name == "conv":
...@@ -124,6 +148,9 @@ class Upsample2D(nn.Module): ...@@ -124,6 +148,9 @@ class Upsample2D(nn.Module):
) -> torch.FloatTensor: ) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if self.use_conv_transpose: if self.use_conv_transpose:
return self.conv(hidden_states) return self.conv(hidden_states)
...@@ -140,10 +167,11 @@ class Upsample2D(nn.Module): ...@@ -140,10 +167,11 @@ class Upsample2D(nn.Module):
# if `output_size` is passed we force the interpolation output # if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2` # size and do not make use of `scale_factor=2`
if output_size is None: if self.interpolate:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") if output_size is None:
else: hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16 # If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
......
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, get_timestep_embedding
from .modeling_utils import ModelMixin
from .normalization import GlobalResponseNorm, RMSNorm
from .resnet import Downsample2D, Upsample2D
class UVit2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
# global config
hidden_size: int = 1024,
use_bias: bool = False,
hidden_dropout: float = 0.0,
# conditioning dimensions
cond_embed_dim: int = 768,
micro_cond_encode_dim: int = 256,
micro_cond_embed_dim: int = 1280,
encoder_hidden_size: int = 768,
# num tokens
vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded
codebook_size: int = 8192,
# `UVit2DConvEmbed`
in_channels: int = 768,
block_out_channels: int = 768,
num_res_blocks: int = 3,
downsample: bool = False,
upsample: bool = False,
block_num_heads: int = 12,
# `TransformerLayer`
num_hidden_layers: int = 22,
num_attention_heads: int = 16,
# `Attention`
attention_dropout: float = 0.0,
# `FeedForward`
intermediate_size: int = 2816,
# `Norm`
layer_norm_eps: float = 1e-6,
ln_elementwise_affine: bool = True,
sample_size: int = 64,
):
super().__init__()
self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
self.embed = UVit2DConvEmbed(
in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
)
self.cond_embed = TimestepEmbedding(
micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
)
self.down_block = UVitBlock(
block_out_channels,
num_res_blocks,
hidden_size,
hidden_dropout,
ln_elementwise_affine,
layer_norm_eps,
use_bias,
block_num_heads,
attention_dropout,
downsample,
False,
)
self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)
self.transformer_layers = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_size,
num_attention_heads=num_attention_heads,
attention_head_dim=hidden_size // num_attention_heads,
dropout=hidden_dropout,
cross_attention_dim=hidden_size,
attention_bias=use_bias,
norm_type="ada_norm_continuous",
ada_norm_continous_conditioning_embedding_dim=hidden_size,
norm_elementwise_affine=ln_elementwise_affine,
norm_eps=layer_norm_eps,
ada_norm_bias=use_bias,
ff_inner_dim=intermediate_size,
ff_bias=use_bias,
attention_out_bias=use_bias,
)
for _ in range(num_hidden_layers)
]
)
self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)
self.up_block = UVitBlock(
block_out_channels,
num_res_blocks,
hidden_size,
hidden_dropout,
ln_elementwise_affine,
layer_norm_eps,
use_bias,
block_num_heads,
attention_dropout,
downsample=False,
upsample=upsample,
)
self.mlm_layer = ConvMlmLayer(
block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
pass
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
micro_cond_embeds = get_timestep_embedding(
micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))
pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)
hidden_states = self.embed(input_ids)
hidden_states = self.down_block(
hidden_states,
pooled_text_emb=pooled_text_emb,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
batch_size, channels, height, width = hidden_states.shape
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
hidden_states = self.project_to_hidden_norm(hidden_states)
hidden_states = self.project_to_hidden(hidden_states)
for layer in self.transformer_layers:
if self.training and self.gradient_checkpointing:
def layer_(*args):
return checkpoint(layer, *args)
else:
layer_ = layer
hidden_states = layer_(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
)
hidden_states = self.project_from_hidden_norm(hidden_states)
hidden_states = self.project_from_hidden(hidden_states)
hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
hidden_states = self.up_block(
hidden_states,
pooled_text_emb=pooled_text_emb,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
logits = self.mlm_layer(hidden_states)
return logits
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
class UVit2DConvEmbed(nn.Module):
def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
super().__init__()
self.embeddings = nn.Embedding(vocab_size, in_channels)
self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
def forward(self, input_ids):
embeddings = self.embeddings(input_ids)
embeddings = self.layer_norm(embeddings)
embeddings = embeddings.permute(0, 3, 1, 2)
embeddings = self.conv(embeddings)
return embeddings
class UVitBlock(nn.Module):
def __init__(
self,
channels,
num_res_blocks: int,
hidden_size,
hidden_dropout,
ln_elementwise_affine,
layer_norm_eps,
use_bias,
block_num_heads,
attention_dropout,
downsample: bool,
upsample: bool,
):
super().__init__()
if downsample:
self.downsample = Downsample2D(
channels,
use_conv=True,
padding=0,
name="Conv2d_0",
kernel_size=2,
norm_type="rms_norm",
eps=layer_norm_eps,
elementwise_affine=ln_elementwise_affine,
bias=use_bias,
)
else:
self.downsample = None
self.res_blocks = nn.ModuleList(
[
ConvNextBlock(
channels,
layer_norm_eps,
ln_elementwise_affine,
use_bias,
hidden_dropout,
hidden_size,
)
for i in range(num_res_blocks)
]
)
self.attention_blocks = nn.ModuleList(
[
SkipFFTransformerBlock(
channels,
block_num_heads,
channels // block_num_heads,
hidden_size,
use_bias,
attention_dropout,
channels,
attention_bias=use_bias,
attention_out_bias=use_bias,
)
for _ in range(num_res_blocks)
]
)
if upsample:
self.upsample = Upsample2D(
channels,
use_conv_transpose=True,
kernel_size=2,
padding=0,
name="conv",
norm_type="rms_norm",
eps=layer_norm_eps,
elementwise_affine=ln_elementwise_affine,
bias=use_bias,
interpolate=False,
)
else:
self.upsample = None
def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
if self.downsample is not None:
x = self.downsample(x)
for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
x = res_block(x, pooled_text_emb)
batch_size, channels, height, width = x.shape
x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
x = attention_block(
x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
)
x = x.permute(0, 2, 1).view(batch_size, channels, height, width)
if self.upsample is not None:
x = self.upsample(x)
return x
class ConvNextBlock(nn.Module):
def __init__(
self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
):
super().__init__()
self.depthwise = nn.Conv2d(
channels,
channels,
kernel_size=3,
padding=1,
groups=channels,
bias=use_bias,
)
self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
self.channelwise_act = nn.GELU()
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
self.channelwise_dropout = nn.Dropout(hidden_dropout)
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
def forward(self, x, cond_embeds):
x_res = x
x = self.depthwise(x)
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.channelwise_linear_1(x)
x = self.channelwise_act(x)
x = self.channelwise_norm(x)
x = self.channelwise_linear_2(x)
x = self.channelwise_dropout(x)
x = x.permute(0, 3, 1, 2)
x = x + x_res
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
return x
class ConvMlmLayer(nn.Module):
def __init__(
self,
block_out_channels: int,
in_channels: int,
use_bias: bool,
ln_elementwise_affine: bool,
layer_norm_eps: float,
codebook_size: int,
):
super().__init__()
self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
def forward(self, hidden_states):
hidden_states = self.conv1(hidden_states)
hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
logits = self.conv2(hidden_states)
return logits
...@@ -88,6 +88,9 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -88,6 +88,9 @@ class VQModel(ModelMixin, ConfigMixin):
vq_embed_dim: Optional[int] = None, vq_embed_dim: Optional[int] = None,
scaling_factor: float = 0.18215, scaling_factor: float = 0.18215,
norm_type: str = "group", # group, spatial norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
lookup_from_codebook=False,
force_upcast=False,
): ):
super().__init__() super().__init__()
...@@ -101,6 +104,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -101,6 +104,7 @@ class VQModel(ModelMixin, ConfigMixin):
act_fn=act_fn, act_fn=act_fn,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
double_z=False, double_z=False,
mid_block_add_attention=mid_block_add_attention,
) )
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
...@@ -119,6 +123,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -119,6 +123,7 @@ class VQModel(ModelMixin, ConfigMixin):
act_fn=act_fn, act_fn=act_fn,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
norm_type=norm_type, norm_type=norm_type,
mid_block_add_attention=mid_block_add_attention,
) )
@apply_forward_hook @apply_forward_hook
...@@ -133,11 +138,13 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -133,11 +138,13 @@ class VQModel(ModelMixin, ConfigMixin):
@apply_forward_hook @apply_forward_hook
def decode( def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
) -> Union[DecoderOutput, torch.FloatTensor]: ) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer # also go through quantization layer
if not force_not_quantize: if not force_not_quantize:
quant, _, _ = self.quantize(h) quant, _, _ = self.quantize(h)
elif self.config.lookup_from_codebook:
quant = self.quantize.get_codebook_entry(h, shape)
else: else:
quant = h quant = h
quant2 = self.post_quant_conv(quant) quant2 = self.post_quant_conv(quant)
......
...@@ -108,6 +108,7 @@ else: ...@@ -108,6 +108,7 @@ else:
"VersatileDiffusionTextToImagePipeline", "VersatileDiffusionTextToImagePipeline",
] ]
) )
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = ["AnimateDiffPipeline"] _import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [ _import_structure["audioldm2"] = [
...@@ -342,6 +343,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -342,6 +343,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * from ..utils.dummy_torch_and_transformers_objects import *
else: else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import AnimateDiffPipeline from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline from .audioldm import AudioLDMPipeline
from .audioldm2 import ( from .audioldm2 import (
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
)
_dummy_objects.update(
{
"AmusedPipeline": AmusedPipeline,
"AmusedImg2ImgPipeline": AmusedImg2ImgPipeline,
"AmusedInpaintPipeline": AmusedInpaintPipeline,
}
)
else:
_import_structure["pipeline_amused"] = ["AmusedPipeline"]
_import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"]
_import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AmusedPipeline,
)
else:
from .pipeline_amused import AmusedPipeline
from .pipeline_amused_img2img import AmusedImg2ImgPipeline
from .pipeline_amused_inpaint import AmusedInpaintPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AmusedPipeline
>>> pipe = AmusedPipeline.from_pretrained(
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
```
"""
class AmusedPipeline(DiffusionPipeline):
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
text_encoder: CLIPTextModelWithProjection
transformer: UVit2DModel
scheduler: AmusedScheduler
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
def __init__(
self,
vqvae: VQModel,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
transformer: UVit2DModel,
scheduler: AmusedScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[List[str], str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 12,
guidance_scale: float = 10.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.IntTensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
output_type="pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
micro_conditioning_aesthetic_score: int = 6,
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
):
"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 16):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 10.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.IntTensor`, *optional*):
Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
gneration. If not provided, the starting latents will be completely masked.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
pooled and projected final hidden states.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
Analogous to `encoder_hidden_states` for the positive prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
Examples:
Returns:
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
`tuple` is returned where the first element is a list with the generated images.
"""
if (prompt_embeds is not None and encoder_hidden_states is None) or (
prompt_embeds is None and encoder_hidden_states is not None
):
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
):
raise ValueError(
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
if isinstance(prompt, str):
prompt = [prompt]
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = batch_size * num_images_per_prompt
if height is None:
height = self.transformer.config.sample_size * self.vae_scale_factor
if width is None:
width = self.transformer.config.sample_size * self.vae_scale_factor
if prompt_embeds is None:
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
prompt_embeds = outputs.text_embeds
encoder_hidden_states = outputs.hidden_states[-2]
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
if guidance_scale > 1.0:
if negative_prompt_embeds is None:
if negative_prompt is None:
negative_prompt = [""] * len(prompt)
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
input_ids = self.tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
negative_prompt_embeds = outputs.text_embeds
negative_encoder_hidden_states = outputs.hidden_states[-2]
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
# Note that the micro conditionings _do_ flip the order of width, height for the original size
# and the crop coordinates. This is how it was done in the original code base
micro_conds = torch.tensor(
[
width,
height,
micro_conditioning_crop_coord[0],
micro_conditioning_crop_coord[1],
micro_conditioning_aesthetic_score,
],
device=self._execution_device,
dtype=encoder_hidden_states.dtype,
)
micro_conds = micro_conds.unsqueeze(0)
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
latents = torch.full(
shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
)
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, timestep in enumerate(self.scheduler.timesteps):
if guidance_scale > 1.0:
model_input = torch.cat([latents] * 2)
else:
model_input = latents
model_output = self.transformer(
model_input,
micro_conds=micro_conds,
pooled_text_emb=prompt_embeds,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
if guidance_scale > 1.0:
uncond_logits, cond_logits = model_output.chunk(2)
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
latents = self.scheduler.step(
model_output=model_output,
timestep=timestep,
sample=latents,
generator=generator,
).prev_sample
if i == len(self.scheduler.timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if output_type == "latent":
output = latents
else:
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
if needs_upcasting:
self.vqvae.float()
output = self.vqvae.decode(
latents,
force_not_quantize=True,
shape=(
batch_size,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
self.vqvae.config.latent_channels,
),
).sample.clip(0, 1)
output = self.image_processor.postprocess(output, output_type)
if needs_upcasting:
self.vqvae.half()
self.maybe_free_model_hooks()
if not return_dict:
return (output,)
return ImagePipelineOutput(output)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AmusedImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> pipe = AmusedImg2ImgPipeline.from_pretrained(
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "winter mountains"
>>> input_image = (
... load_image(
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg"
... )
... .resize((512, 512))
... .convert("RGB")
... )
>>> image = pipe(prompt, input_image).images[0]
```
"""
class AmusedImg2ImgPipeline(DiffusionPipeline):
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
text_encoder: CLIPTextModelWithProjection
transformer: UVit2DModel
scheduler: AmusedScheduler
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
# off the meta device. There should be a way to fix this instead of just not offloading it
_exclude_from_cpu_offload = ["vqvae"]
def __init__(
self,
vqvae: VQModel,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
transformer: UVit2DModel,
scheduler: AmusedScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[List[str], str]] = None,
image: PipelineImageInput = None,
strength: float = 0.5,
num_inference_steps: int = 12,
guidance_scale: float = 10.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[torch.Generator] = None,
prompt_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
output_type="pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
micro_conditioning_aesthetic_score: int = 6,
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
):
"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.5):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 16):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 10.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
pooled and projected final hidden states.
encoder_hidden_states (`torch.FloatTensor`, *optional*):
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
Analogous to `encoder_hidden_states` for the positive prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
Examples:
Returns:
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
`tuple` is returned where the first element is a list with the generated images.
"""
if (prompt_embeds is not None and encoder_hidden_states is None) or (
prompt_embeds is None and encoder_hidden_states is not None
):
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
):
raise ValueError(
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
if isinstance(prompt, str):
prompt = [prompt]
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
batch_size = batch_size * num_images_per_prompt
if prompt_embeds is None:
input_ids = self.tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
prompt_embeds = outputs.text_embeds
encoder_hidden_states = outputs.hidden_states[-2]
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
if guidance_scale > 1.0:
if negative_prompt_embeds is None:
if negative_prompt is None:
negative_prompt = [""] * len(prompt)
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
input_ids = self.tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids.to(self._execution_device)
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
negative_prompt_embeds = outputs.text_embeds
negative_encoder_hidden_states = outputs.hidden_states[-2]
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
image = self.image_processor.preprocess(image)
height, width = image.shape[-2:]
# Note that the micro conditionings _do_ flip the order of width, height for the original size
# and the crop coordinates. This is how it was done in the original code base
micro_conds = torch.tensor(
[
width,
height,
micro_conditioning_crop_coord[0],
micro_conditioning_crop_coord[1],
micro_conditioning_aesthetic_score,
],
device=self._execution_device,
dtype=encoder_hidden_states.dtype,
)
micro_conds = micro_conds.unsqueeze(0)
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
if needs_upcasting:
self.vqvae.float()
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
latents_bsz, channels, latents_height, latents_width = latents.shape
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
latents = self.scheduler.add_noise(
latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
)
latents = latents.repeat(num_images_per_prompt, 1, 1)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
timestep = self.scheduler.timesteps[i]
if guidance_scale > 1.0:
model_input = torch.cat([latents] * 2)
else:
model_input = latents
model_output = self.transformer(
model_input,
micro_conds=micro_conds,
pooled_text_emb=prompt_embeds,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
)
if guidance_scale > 1.0:
uncond_logits, cond_logits = model_output.chunk(2)
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
latents = self.scheduler.step(
model_output=model_output,
timestep=timestep,
sample=latents,
generator=generator,
).prev_sample
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
if output_type == "latent":
output = latents
else:
output = self.vqvae.decode(
latents,
force_not_quantize=True,
shape=(
batch_size,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
self.vqvae.config.latent_channels,
),
).sample.clip(0, 1)
output = self.image_processor.postprocess(output, output_type)
if needs_upcasting:
self.vqvae.half()
self.maybe_free_model_hooks()
if not return_dict:
return (output,)
return ImagePipelineOutput(output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment