Unverified commit c2717317 authored by Younes Belkada, committed by GitHub

[`PEFT`] Adapt example scripts to use PEFT (#5388)



* adapt example scripts to use PEFT

* Update examples/text_to_image/train_text_to_image_lora.py

* fix

* add for SDXL

* oops

* make sure to install peft

* fix

* fix

* fix dreambooth and lora

* more fixes

* add peft to requirements.txt

* fix

* final fix

* add peft version in requirements

* remove comment

* change variable names

* add few lines in readme

* add to reqs

* style

* fix issues

* fix lora dreambooth xl tests

* init_lora_weights to gaussian and add out proj where missing

* amend requirements.

* amend requirements.txt

* add correct peft versions

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent bf7f9b49
@@ -113,6 +113,7 @@ jobs:
       - name: Run example PyTorch CPU tests
         if: ${{ matrix.config.framework == 'pytorch_examples' }}
         run: |
+          python -m pip install peft
           python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
             examples
......
@@ -44,6 +44,7 @@ write_basic_config()
 ```
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 ### Dog toy example
......
@@ -47,6 +47,7 @@ write_basic_config()
 ```
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 ### Dog toy example
......
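For readers of the README notes added above: this is roughly the pattern the PEFT-backed example scripts now follow — freeze the base model, describe the LoRA layers with a `LoraConfig`, and attach the adapter with `add_adapter`. A minimal sketch, assuming `peft>=0.6.0` and a recent `diffusers`; the model id and rank below are illustrative, not taken from the diff.

```python
# Minimal sketch of the PEFT-backed LoRA setup (illustrative model id and rank).
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32
)

# Freeze the base weights; only the injected LoRA matrices will require gradients.
unet.requires_grad_(False)

unet_lora_config = LoraConfig(
    r=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# `add_adapter` is the diffusers PEFT integration used throughout this PR.
unet.add_adapter(unet_lora_config)

trainable = [p for p in unet.parameters() if p.requires_grad]
print(f"LoRA parameters to optimize: {sum(p.numel() for p in trainable):,}")
```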
@@ -4,3 +4,4 @@ transformers>=4.25.1
 ftfy
 tensorboard
 Jinja2
+peft==0.7.0
\ No newline at end of file
@@ -4,3 +4,4 @@ transformers>=4.25.1
 ftfy
 tensorboard
 Jinja2
+peft==0.7.0
\ No newline at end of file
@@ -16,7 +16,6 @@
 import argparse
 import copy
 import gc
-import itertools
 import logging
 import math
 import os
@@ -35,6 +34,8 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
@@ -52,14 +53,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.attention_processor import (
-    AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    SlicedAttnAddedKVProcessor,
-)
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -864,79 +858,19 @@ def main(args):
             text_encoder.gradient_checkpointing_enable()

     # now we will add new LoRA weights to the attention layers
-    # It's important to realize here how many attention weights will be added and of which sizes
-    # The sizes of the attention layers consist only of two different variables:
-    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
-    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-    # Let's first see how many attention processors we will have to set.
-    # For Stable Diffusion, it should be equal to:
-    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
-    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
-    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
-    # => 32 layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
-        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
-            attn_module.add_k_proj.set_lora_layer(
-                LoRALinearLayer(
-                    in_features=attn_module.add_k_proj.in_features,
-                    out_features=attn_module.add_k_proj.out_features,
-                    rank=args.rank,
-                )
-            )
-            attn_module.add_v_proj.set_lora_layer(
-                LoRALinearLayer(
-                    in_features=attn_module.add_v_proj.in_features,
-                    out_features=attn_module.add_v_proj.out_features,
-                    rank=args.rank,
-                )
-            )
-            unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
-            unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+    )
+    unet.add_adapter(unet_lora_config)

-    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks.
+    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
+        text_lora_config = LoraConfig(
+            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+        )
+        text_encoder.add_adapter(text_lora_config)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -948,9 +882,9 @@ def main(args):
         for model in models:
             if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_lora_state_dict(model)
+                unet_lora_layers_to_save = get_peft_model_state_dict(model)
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1010,11 +944,10 @@ def main(args):
         optimizer_class = torch.optim.AdamW

     # Optimizer creation
-    params_to_optimize = (
-        itertools.chain(unet_lora_parameters, text_lora_parameters)
-        if args.train_text_encoder
-        else unet_lora_parameters
-    )
+    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    if args.train_text_encoder:
+        params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -1257,12 +1190,7 @@ def main(args):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet_lora_parameters, text_lora_parameters)
-                        if args.train_text_encoder
-                        else unet_lora_parameters
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -1385,19 +1313,19 @@ def main(args):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
-        if text_encoder is not None and args.train_text_encoder:
+        unet_lora_state_dict = get_peft_model_state_dict(unet)
+
+        if args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
-            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+            text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
         else:
-            text_encoder_lora_layers = None
+            text_encoder_state_dict = None

         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
-            unet_lora_layers=unet_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
+            unet_lora_layers=unet_lora_state_dict,
+            text_encoder_lora_layers=text_encoder_state_dict,
         )

     # Final inference
......
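Once the DreamBooth LoRA run above finishes, the weights written by `LoraLoaderMixin.save_lora_weights` can be loaded back for inference. A hedged usage sketch; the base model id, output path, and prompt are placeholders rather than values from the diff.

```python
# Hedged sketch: load the LoRA weights saved by the updated DreamBooth script for inference.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# `save_lora_weights(...)` writes the LoRA weights into the output directory;
# `load_lora_weights` restores them onto the unet (and text encoder, if it was trained).
pipe.load_lora_weights("path/to/output_dir")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-lora.png")
```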
@@ -34,6 +34,8 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
@@ -50,9 +52,8 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
+from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -1009,54 +1010,19 @@ def main(args):
             text_encoder_two.gradient_checkpointing_enable()

     # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+    )
+    unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
-        )
+        text_lora_config = LoraConfig(
+            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -1069,11 +1035,11 @@ def main(args):
         for model in models:
             if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_lora_state_dict(model)
+                unet_lora_layers_to_save = get_peft_model_state_dict(model)
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1130,6 +1096,12 @@ def main(args):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
     # Optimization parameters
     unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
     if args.train_text_encoder:
@@ -1194,26 +1166,10 @@ def main(args):
         optimizer_class = prodigyopt.Prodigy

-        if args.learning_rate <= 0.1:
-            logger.warn(
-                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
-            )
-        if args.train_text_encoder and args.text_encoder_lr:
-            logger.warn(
-                f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
-                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
-                f"When using prodigy only learning_rate is used as the initial learning rate."
-            )
-            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
-            # --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
-            params_to_optimize[2]["lr"] = args.learning_rate
-
         optimizer = optimizer_class(
             params_to_optimize,
             lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
-            beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=args.prodigy_decouple,
@@ -1659,13 +1615,13 @@ def main(args):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
......
@@ -32,6 +32,8 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
 accelerate config
 ```
+
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 ### Pokemon example
 You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
......
@@ -45,6 +45,7 @@ write_basic_config()
 ```
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 ### Training
......
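The README notes above only state the `peft>=0.6.0` requirement. A quick runtime check along these lines can catch an outdated install before a long training run starts; this snippet is illustrative and not part of the scripts in this PR.

```python
# Illustrative check that the environment satisfies the `peft>=0.6.0` requirement.
import importlib.metadata

from packaging import version

peft_version = version.parse(importlib.metadata.version("peft"))
if peft_version < version.parse("0.6.0"):
    raise RuntimeError(f"peft>=0.6.0 is required for the PEFT LoRA backend, found {peft_version}")
print(f"peft {peft_version} OK")
```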
@@ -5,3 +5,4 @@ datasets
 ftfy
 tensorboard
 Jinja2
+peft==0.7.0
\ No newline at end of file
@@ -5,3 +5,4 @@ ftfy
 tensorboard
 Jinja2
 datasets
+peft==0.7.0
\ No newline at end of file
@@ -34,13 +34,14 @@ from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer

 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.models.lora import LoRALinearLayer
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -479,62 +480,20 @@ def main():
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

+    # Freeze the unet parameters before adding adapters
+    for param in unet.parameters():
+        param.requires_grad_(False)
+
+    unet_lora_config = LoraConfig(
+        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+    )
+
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)

-    # now we will add new LoRA weights to the attention layers
-    # It's important to realize here how many attention weights will be added and of which sizes
-    # The sizes of the attention layers consist only of two different variables:
-    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
-    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-    # Let's first see how many attention processors we will have to set.
-    # For Stable Diffusion, it should be equal to:
-    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
-    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
-    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
-    # => 32 layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet.add_adapter(unet_lora_config)

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -549,6 +508,8 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -573,7 +534,7 @@ def main():
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        unet_lora_parameters,
+        lora_layers,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -700,8 +661,8 @@ def main():
     )

     # Prepare everything with our `accelerator`.
-    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
     )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -833,7 +794,7 @@ def main():
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = unet_lora_parameters
+                    params_to_clip = lora_layers
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -870,6 +831,15 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)

+                        unet_lora_state_dict = get_peft_model_state_dict(unet)
+
+                        StableDiffusionPipeline.save_lora_weights(
+                            save_directory=save_path,
+                            unet_lora_layers=unet_lora_state_dict,
+                            safe_serialization=True,
+                        )
+
                         logger.info(f"Saved state to {save_path}")

             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -926,7 +896,13 @@ def main():
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
+
+        unet_lora_state_dict = get_peft_model_state_dict(unet)
+        StableDiffusionPipeline.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_state_dict,
+            safe_serialization=True,
+        )

         if args.push_to_hub:
             save_model_card(
......
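A side effect of the change above is that every `checkpoint-<step>` directory now contains LoRA weights written via `StableDiffusionPipeline.save_lora_weights`, alongside the accelerate state, so intermediate checkpoints can be sampled directly. A hedged sketch; the base model id, checkpoint path, and prompt are placeholders.

```python
# Hedged sketch: sample from an intermediate LoRA checkpoint written by the updated script.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Placeholder path; point this at one of the checkpoint directories in your output dir.
pipe.load_lora_weights("output_dir/checkpoint-500")

image = pipe("yoda pokemon", num_inference_steps=30).images[0]
image.save("pokemon-lora-checkpoint.png")
```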
@@ -16,7 +16,6 @@
 """Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""

 import argparse
-import itertools
 import logging
 import math
 import os
@@ -37,6 +36,8 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -50,7 +51,6 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -658,53 +658,20 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+    )
+    unet.add_adapter(unet_lora_config)

-    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks.
+    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
-        )
+        text_lora_config = LoraConfig(
+            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -717,11 +684,11 @@ def main(args):
         for model in models:
             if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                unet_lora_layers_to_save = get_peft_model_state_dict(model)
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -792,11 +759,13 @@ def main(args):
         optimizer_class = torch.optim.AdamW

     # Optimizer creation
-    params_to_optimize = (
-        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
-        if args.train_text_encoder
-        else unet_lora_parameters
-    )
+    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    if args.train_text_encoder:
+        params_to_optimize = (
+            params_to_optimize
+            + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+            + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+        )
+
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -1128,12 +1097,7 @@ def main(args):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
-                        if args.train_text_encoder
-                        else unet_lora_parameters
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -1229,20 +1193,21 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
-        unet_lora_layers = unet_attn_processors_state_dict(unet)
+        unet_lora_state_dict = get_peft_model_state_dict(unet)

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two)
+
+            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
+            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None

         StableDiffusionXLPipeline.save_lora_weights(
             save_directory=args.output_dir,
-            unet_lora_layers=unet_lora_layers,
+            unet_lora_layers=unet_lora_state_dict,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
         )
......
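For the SDXL variant above, the weights saved by `StableDiffusionXLPipeline.save_lora_weights` (UNet plus both text encoders when `--train_text_encoder` is used) load back with the standard LoRA loader. A hedged sketch; the base model id, path, and prompt are placeholders.

```python
# Hedged sketch: load the SDXL LoRA weights saved by the updated script for inference.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Restores the LoRA layers onto the unet and, if present, both text encoders.
pipe.load_lora_weights("path/to/output_dir")

image = pipe("a pokemon with blue eyes", num_inference_steps=30).images[0]
image.save("sdxl-lora.png")
```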