"...text-generation-inference.git" did not exist on "b1f9044d6cf082423a517cf9a6aa6e5ebd34e1c2"
Unverified commit d03c9099, authored by Kashif Rasul and committed by GitHub

[Wuerstchen] text to image training script (#5052)



* initial script

* formatting

* prior trainer wip

* add efficient_net_encoder

* add CLIPTextModel

* add prior ema support

* optimizer

* fix typo

* add dataloader

* prompt_embeds and image_embeds

* initial training loop

* fix output_dir

* fix add_noise

* accelerator check

* make effnet_transforms dynamic

* fix training loop

* add validation logging

* use loaded text_encoder

* use PreTrainedTokenizerFast

* load weight from pickle

* save_model_card

* remove unused file

* fix typos

* save prior pipeline in its own folder

* fix imports

* fix pipe_t2i

* scale image_embeds

* remove snr_gamma

* format

* initial lora prior training

* log_validation and save

* initial gradient working

* remove save/load hooks

* set set_attn_processor on prior_prior

* add lora script

* typos

* use LoraLoaderMixin for prior pipeline

* fix usage

* make fix-copies

* use repo_id

* write_lora_layers is a staticmethod

* use defaults

* fix defaults

* undo

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add gradient checkpointing support to prior

* gradient_checkpointing

* formatting

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/README.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/loaders.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* use default unet and text_encoder

* fix test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
# Würstchen text-to-image fine-tuning
## Running locally with PyTorch
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd into the example folder and run
```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
For this example we want to push the trained weights directly to the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training command. To log in, run:
```bash
huggingface-cli login
```
## Prior training
You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that `--gradient_checkpointing` is supported for prior fine-tuning, so you can use the script in more GPU-memory-constrained setups.
<br>
<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot pokemon, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-pokemon-model"
```
<!-- accelerate_snippet_end -->
## Training with LoRA
Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- The pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- The rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers let you control the extent to which the model is adapted toward new training images via a `scale` parameter (a minimal sketch of the idea follows this list).
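A minimal sketch of the rank-decomposition idea in plain PyTorch. This is illustrative only: the actual implementation lives in the `diffusers` attention processors, and the class and attribute names below are hypothetical.
```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen base layer plus a trainable low-rank update, scaled by `scale`."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # the update starts as a no-op
        self.scale = scale  # controls how strongly the adaptation is applied

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))
```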
### Prior Training
First, you need to set up your development environment as explained in the [installation](#running-locally-with-pytorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image_lora_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=768 \
  --train_batch_size=8 \
  --num_train_epochs=100 --checkpointing_steps=5000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=4 \
  --validation_prompt="cute dragon creature" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-pokemon-lora"
```
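After training, the LoRA layers can be loaded back into the prior pipeline through the `LoraLoaderMixin` support added in this PR. A rough sketch, assuming the weights were saved to (or pushed under) the output directory used above:
```python
import torch
from diffusers import WuerstchenPriorPipeline

prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
).to("cuda")
# Local output directory or Hub repo produced by the LoRA training run above.
prior_pipeline.load_lora_weights("wuerstchen-prior-pokemon-lora")
image_embeddings = prior_pipeline("cute dragon creature", guidance_scale=4.0).image_embeddings
```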
```python
import torch.nn as nn
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class EfficientNetEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
        super().__init__()
        if effnet == "efficientnet_v2_s":
            self.backbone = efficientnet_v2_s(weights="DEFAULT").features
        else:
            self.backbone = efficientnet_v2_l(weights="DEFAULT").features
        self.mapper = nn.Sequential(
            nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        )

    def forward(self, x):
        return self.mapper(self.backbone(x))
```
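A quick smoke test of the encoder above (the shapes are an assumption based on torchvision's EfficientNetV2-S, whose `features` output has 1280 channels at 1/32 resolution, so a 768x768 crop maps to a 16x24x24 latent):
```python
import torch

encoder = EfficientNetEncoder()  # downloads torchvision's pretrained EfficientNetV2-S weights
with torch.no_grad():
    image_embeds = encoder(torch.randn(2, 3, 768, 768))
print(image_embeds.shape)  # expected: torch.Size([2, 16, 24, 24])
```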
```txt
accelerate>=0.16.0
torchvision
transformers>=4.25.1
wandb
huggingface-cli
bitsandbytes
deepspeed
```
@@ -1208,7 +1208,7 @@ class LoraLoaderMixin:
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=self.unet,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
@@ -1216,7 +1216,9 @@ class LoraLoaderMixin:
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
text_encoder=getattr(self, self.text_encoder_name)
if not hasattr(self, "text_encoder")
else self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -1577,7 +1579,7 @@ class LoraLoaderMixin:
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
@@ -1961,7 +1963,7 @@ class LoraLoaderMixin:
@classmethod
def save_lora_weights(
self,
cls,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
@@ -2001,7 +2003,7 @@ class LoraLoaderMixin:
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
)
unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
state_dict.update(unet_lora_state_dict)
if text_encoder_lora_layers is not None:
@@ -2012,12 +2014,12 @@ class LoraLoaderMixin:
)
text_encoder_lora_state_dict = {
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
}
state_dict.update(text_encoder_lora_state_dict)
# Save the model
self.write_lora_layers(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
@@ -2026,6 +2028,7 @@ class LoraLoaderMixin:
safe_serialization=safe_serialization,
)
@staticmethod
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
save_directory: str,
@@ -3248,7 +3251,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
@classmethod
def save_lora_weights(
self,
cls,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
@@ -3299,7 +3302,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
self.write_lora_layers(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
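The `getattr(self, self.unet_name)` indirection in the `LoraLoaderMixin` changes above exists because the Würstchen prior pipeline stores its denoiser under `prior` rather than `unet`. A toy illustration of the lookup, not library code:
```python
class ToyLoraMixin:
    unet_name = "unet"  # subclasses may point this at a differently named attribute

    def _resolve_denoiser(self):
        # Fall back to the attribute named by `unet_name` when the pipeline has no `.unet`.
        return getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet


class ToyPriorPipeline(ToyLoraMixin):
    unet_name = "prior"

    def __init__(self, prior):
        self.prior = prior


assert ToyPriorPipeline(prior="prior-module")._resolve_denoiser() == "prior-module"
```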
@@ -14,16 +14,29 @@
# limitations under the License.
import math
from typing import Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from ...models.modeling_utils import ModelMixin
from ...utils import is_torch_version
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
class WuerstchenPrior(ModelMixin, ConfigMixin):
class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
unet_name = "prior"
_supports_gradient_checkpointing = True
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
super().__init__()
@@ -45,6 +58,90 @@ class WuerstchenPrior(ModelMixin, ConfigMixin):
nn.Conv2d(c, c_in * 2, kernel_size=1),
)
self.gradient_checkpointing = False
self.set_default_attn_processor()
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
@@ -61,12 +158,42 @@ class WuerstchenPrior(ModelMixin, ConfigMixin):
x = self.projection(x)
c_embed = self.cond_mapper(c)
r_embed = self.gen_r_embedding(r)
for block in self.blocks:
if isinstance(block, AttnBlock):
x = block(x, c_embed)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
for block in self.blocks:
if isinstance(block, AttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, c_embed, use_reentrant=False
)
elif isinstance(block, TimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
else:
x = block(x)
for block in self.blocks:
if isinstance(block, AttnBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
elif isinstance(block, TimestepBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
else:
for block in self.blocks:
if isinstance(block, AttnBlock):
x = block(x, c_embed)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
a, b = self.out(x).chunk(2, dim=1)
return (x_in - a) / ((1 - b).abs() + 1e-5)
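The gradient-checkpointing branch added above trades compute for memory: activations inside each block are recomputed during the backward pass instead of being cached. A standalone sketch of the same pattern (not the Würstchen blocks, just the checkpointing mechanics):
```python
import torch
import torch.nn as nn

blocks = nn.ModuleList(nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4))


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


x = torch.randn(2, 64, requires_grad=True)
for block in blocks:
    # Activations inside `block` are recomputed in backward, saving memory during training.
    x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
x.sum().backward()
```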
@@ -20,6 +20,7 @@ import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import LoraLoaderMixin
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import (
BaseOutput,
@@ -65,7 +66,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
image_embeddings: Union[torch.FloatTensor, np.ndarray]
class WuerstchenPriorPipeline(DiffusionPipeline):
class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
"""
Pipeline for generating image prior for Wuerstchen.
@@ -90,6 +91,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
Default resolution for multiple images generated.
"""
unet_name = "prior"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->prior"
def __init__(
@@ -211,24 +211,15 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
device = original_samples.device
dtype = original_samples.dtype
alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
)
noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
return noisy_samples.to(dtype=dtype)
def __len__(self):
return self.config.num_train_timesteps
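The rewritten `add_noise` above evaluates the cumulative alpha product from continuous float timesteps instead of indexing a precomputed table, then applies the usual forward-diffusion step q(x_t | x_0). A shape-respecting sketch of that step, with an assumed per-sample `alpha_cumprod_t`:
```python
import torch


def add_noise_sketch(x0: torch.Tensor, noise: torch.Tensor, alpha_cumprod_t: torch.Tensor) -> torch.Tensor:
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with alpha_bar_t per sample."""
    # Broadcast the per-sample alpha_bar over all non-batch dimensions, mirroring the scheduler's `.view(...)`.
    a = alpha_cumprod_t.view(x0.size(0), *[1] * (x0.dim() - 1))
    return a.sqrt() * x0 + (1 - a).sqrt() * noise
```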