"src/vscode:/vscode.git/clone" did not exist on "5dd35580f7918faa1de551cd80a0ce90a143c434"
Unverified Commit 365e8461, authored by Sayak Paul, committed by GitHub

[SDXL DreamBooth LoRA] add support for text encoder fine-tuning (#4097)



* Allow low precision sd xl

* finish

* finish

* feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth

* fix: variable assignments.

* add: autocast block.

* add debugging

* vae dtype hell

* fix: vae dtype hell.

* fix: vae dtype hell 3.

* clean up

* lora text encoder loader.

* fix: unwrapping models.

* add: tests.

* docs.

* handle unexpected keys.

* fix vae dtype in the final inference.

* fix scope problem.

* fix: save_model_card args.

* initialize: prefix to None.

* fix: dtype issues.

* apply fixes.

* debugging.

* debugging

* debugging

* debugging

* debugging

* debugging

* add: fast tests.

* pre-tokenize.

* address: will's comments.

* fix: loader and tests.

* fix: dataloader.

* simplify dataloader.

* length.

* simplification.

* make style && make quality

* simplify state_dict munging

* fix: tests.

* fix: state_dict packing.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fed12376
@@ -164,6 +164,17 @@ Here's a side-by-side comparison of the with and without Refiner pipeline output
 |---|---|
 | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
+
+### Training with text encoder(s)
+
+Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
+
+* SDXL has two text encoders, so we fine-tune both of them with LoRA.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
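After training, `StableDiffusionXLPipeline.load_lora_weights()` can consume the serialized LoRA parameters directly; it routes the UNet layers and, when present, the layers of both text encoders. A minimal inference sketch (the base model id and output path below are illustrative):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Loads the UNet LoRA layers and, when present, the LoRA layers of both text encoders.
pipe.load_lora_weights("path/to/output_dir")

image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```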
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
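The same fix applies at inference time; a minimal sketch of swapping in that VAE (the base model id is illustrative):

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# A VAE fine-tuned to be numerically stable in fp16.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
).to("cuda")
```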
+
 ## Notes
 In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
...
@@ -16,6 +16,7 @@
 import argparse
 import gc
 import hashlib
+import itertools
 import logging
 import math
 import os
@@ -45,11 +46,11 @@ import diffusers
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
-    DiffusionPipeline,
     DPMSolverMultistepScheduler,
+    StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
 from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
@@ -63,12 +64,7 @@ logger = get_logger(__name__)
 def save_model_card(
-    repo_id: str,
-    images=None,
-    base_model=str,
-    train_text_encoder=False,
-    prompt=str,
-    repo_folder=None,
+    repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
 ):
     img_str = ""
     for i, image in enumerate(images):
@@ -96,6 +92,8 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
 {img_str}

 LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
@@ -130,6 +128,12 @@ def parse_args(input_args=None):
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--pretrained_vae_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -420,38 +424,25 @@ def parse_args(input_args=None):
     if args.class_prompt is not None:
         warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

+    if args.train_text_encoder and args.pre_compute_text_embeddings:
+        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
     return args
 class DreamBoothDataset(Dataset):
     """
     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
+    It pre-processes the images.
     """

     def __init__(
         self,
         instance_data_root,
-        instance_prompt,
         class_data_root=None,
-        class_prompt=None,
         class_num=None,
         size=1024,
         center_crop=False,
-        instance_prompt_hidden_states=None,
-        class_prompt_hidden_states=None,
-        instance_unet_added_conditions=None,
-        class_unet_added_conditions=None,
     ):
         self.size = size
         self.center_crop = center_crop
-        self.instance_prompt_hidden_states = instance_prompt_hidden_states
-        self.class_prompt_hidden_states = class_prompt_hidden_states
-        self.instance_unet_added_conditions = instance_unet_added_conditions
-        self.class_unet_added_conditions = class_unet_added_conditions

         self.instance_data_root = Path(instance_data_root)
         if not self.instance_data_root.exists():
@@ -459,7 +450,6 @@ class DreamBoothDataset(Dataset):
         self.instance_images_path = list(Path(instance_data_root).iterdir())
         self.num_instance_images = len(self.instance_images_path)
-        self.instance_prompt = instance_prompt
         self._length = self.num_instance_images

         if class_data_root is not None:
@@ -471,7 +461,6 @@ class DreamBoothDataset(Dataset):
             else:
                 self.num_class_images = len(self.class_images_path)
             self._length = max(self.num_class_images, self.num_instance_images)
-            self.class_prompt = class_prompt
         else:
             self.class_data_root = None
@@ -496,9 +485,6 @@ class DreamBoothDataset(Dataset):
             instance_image = instance_image.convert("RGB")
         example["instance_images"] = self.image_transforms(instance_image)
-        example["instance_prompt_ids"] = self.instance_prompt_hidden_states
-        example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions

         if self.class_data_root:
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             class_image = exif_transpose(class_image)
@@ -506,49 +492,22 @@ class DreamBoothDataset(Dataset):
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
             example["class_images"] = self.image_transforms(class_image)
-            example["class_prompt_ids"] = self.class_prompt_hidden_states
-            example["class_added_cond_kwargs"] = self.class_unet_added_conditions

         return example
 def collate_fn(examples, with_prior_preservation=False):
-    has_attention_mask = "instance_attention_mask" in examples[0]
-
-    input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
-    add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples]
-    add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples]
-
-    if has_attention_mask:
-        attention_mask = [example["instance_attention_mask"] for example in examples]

     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
     if with_prior_preservation:
-        input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
-        add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples]
-        add_time_ids += [example["class_added_cond_kwargs"]["time_ids"] for example in examples]
-        if has_attention_mask:
-            attention_mask += [example["class_attention_mask"] for example in examples]

     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

-    input_ids = torch.cat(input_ids, dim=0)
-    add_text_embeds = torch.cat(add_text_embeds, dim=0)
-    add_time_ids = torch.cat(add_time_ids, dim=0)
-
-    batch = {
-        "input_ids": input_ids,
-        "pixel_values": pixel_values,
-        "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
-    }
-
-    if has_attention_mask:
-        batch["attention_mask"] = attention_mask
-
+    batch = {"pixel_values": pixel_values}
     return batch
@@ -569,27 +528,29 @@ class PromptDataset(Dataset):
         return example


+def tokenize_prompt(tokenizer, prompt):
+    text_inputs = tokenizer(
+        prompt,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+    return text_input_ids
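For reference, a quick sketch of what the `tokenize_prompt` helper above produces with a CLIP tokenizer (the checkpoint id is illustrative): every prompt is padded or truncated to the tokenizer's `model_max_length`, which is 77 for CLIP.

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids = tokenize_prompt(tokenizer, "a photo of sks dog")
print(ids.shape)  # torch.Size([1, 77]), padded to tokenizer.model_max_length
```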
 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(text_encoders, tokenizers, prompt):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
     prompt_embeds_list = []

-    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-        text_inputs = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {tokenizer.model_max_length} tokens: {removed_text}"
-            )
+    for i, text_encoder in enumerate(text_encoders):
+        if tokenizers is not None:
+            tokenizer = tokenizers[i]
+            text_input_ids = tokenize_prompt(tokenizer, prompt)
+        else:
+            assert text_input_ids_list is not None
+            text_input_ids = text_input_ids_list[i]

         prompt_embeds = text_encoder(
             text_input_ids.to(text_encoder.device),
@@ -641,9 +602,6 @@ def main(args):
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
         import wandb

-    if args.train_text_encoder:
-        raise NotImplementedError("Text encoder training not yet supported.")
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -677,7 +635,7 @@ def main(args):
                 torch_dtype = torch.float16
             elif args.prior_generation_precision == "bf16":
                 torch_dtype = torch.bfloat16
-            pipeline = DiffusionPipeline.from_pretrained(
+            pipeline = StableDiffusionXLPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
                 safety_checker=None,
@@ -742,7 +700,14 @@ def main(args):
     text_encoder_two = text_encoder_cls_two.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    vae_path = (
+        args.pretrained_model_name_or_path
+        if args.pretrained_vae_model_name_or_path is None
+        else args.pretrained_vae_model_name_or_path
+    )
+    vae = AutoencoderKL.from_pretrained(
+        vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+    )
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
@@ -764,7 +729,10 @@ def main(args):
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=torch.float32)
+    if args.pretrained_vae_model_name_or_path is None:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
@@ -804,42 +772,66 @@ def main(args):
             unet_lora_parameters.extend(module.parameters())

     unet.set_attn_processor(unet_lora_attn_procs)
-    # unet_lora_layers = AttnProcsLayers(unet.attn_processors)
+
+    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
+    if args.train_text_encoder:
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
+        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         # there are only two options here. Either are just the unet attn processor layers
         # or there are the unet and text encoder atten layers
         unet_lora_layers_to_save = None
+        text_encoder_one_lora_layers_to_save = None
+        text_encoder_two_lora_layers_to_save = None

         for model in models:
-            unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")

             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()

-        LoraLoaderMixin.save_lora_weights(
+        StableDiffusionXLPipeline.save_lora_weights(
             output_dir,
             unet_lora_layers=unet_lora_layers_to_save,
-            text_encoder_lora_layers=None,
+            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+            text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
         )
     def load_model_hook(models, input_dir):
         unet_ = None
-        text_encoder_ = None
+        text_encoder_one_ = None
+        text_encoder_two_ = None

         while len(models) > 0:
             model = models.pop()

             if isinstance(model, type(accelerator.unwrap_model(unet))):
                 unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                text_encoder_one_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

         lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
-        )
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
+        )
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
+        )

     accelerator.register_save_state_pre_hook(save_model_hook)
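The `_modify_text_encoder` calls above monkey-patch the text encoders' attention projections in place. Conceptually (this is an illustrative sketch with a hypothetical class name, not the actual diffusers implementation), each frozen `nn.Linear` is wrapped so its output gains a trainable low-rank update:

```python
import torch
import torch.nn as nn

class LoraPatchedLinear(nn.Module):
    """Illustrative LoRA wrapper: frozen base projection + trainable low-rank update."""

    def __init__(self, linear: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = linear  # pretrained projection, kept frozen
        self.down = nn.Linear(linear.in_features, rank, bias=False)
        self.up = nn.Linear(rank, linear.out_features, bias=False)
        self.scale = scale
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # the update starts as a no-op

    def forward(self, hidden_states):
        return self.linear(hidden_states) + self.scale * self.up(self.down(hidden_states))
```

Only the `down`/`up` parameters are handed to the optimizer, which is why the script collects `text_lora_parameters_one`/`text_lora_parameters_two` separately from the frozen encoder weights.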
@@ -869,7 +861,11 @@ def main(args):
         optimizer_class = torch.optim.AdamW

     # Optimizer creation
-    params_to_optimize = unet_lora_parameters
+    params_to_optimize = (
+        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+        if args.train_text_encoder
+        else unet_lora_parameters
+    )
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -878,62 +874,81 @@ def main(args):
         eps=args.adam_epsilon,
     )
-    # We ALWAYS pre-compute the additional condition embeddings needed for SDXL
-    # UNet as the model is already big and it uses two text encoders.
-    # TODO: when we add support for text encoder training, will revisit.
-    tokenizers = [tokenizer_one, tokenizer_two]
-    text_encoders = [text_encoder_one, text_encoder_two]
-
-    # Here, we compute not just the text embeddings but also the additional embeddings
-    # needed for the SD XL UNet to operate.
-    def compute_embeddings(prompt, text_encoders, tokenizers):
+    # Computes additional embeddings/ids required by the SDXL UNet.
+    # regular text embeddings (when `train_text_encoder` is not True)
+    # pooled text embeddings
+    # time ids
+
+    def compute_time_ids():
+        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
         original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
         crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
-
-        with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
-            add_text_embeds = pooled_prompt_embeds
-
-            # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-            add_time_ids = list(original_size + crops_coords_top_left + target_size)
-            add_time_ids = torch.tensor([add_time_ids])
-
-            prompt_embeds = prompt_embeds.to(accelerator.device)
-            add_text_embeds = add_text_embeds.to(accelerator.device)
-            add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
-            unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-
-        return prompt_embeds, unet_added_cond_kwargs
-
-    instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings(
-        args.instance_prompt, text_encoders, tokenizers
-    )
-
-    class_prompt_hidden_states, class_unet_added_conditions = None, None
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        add_time_ids = torch.tensor([add_time_ids])
+        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+        return add_time_ids
+
+    if not args.train_text_encoder:
+        tokenizers = [tokenizer_one, tokenizer_two]
+        text_encoders = [text_encoder_one, text_encoder_two]
+
+        def compute_text_embeddings(prompt, text_encoders, tokenizers):
+            with torch.no_grad():
+                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+                prompt_embeds = prompt_embeds.to(accelerator.device)
+                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+            return prompt_embeds, pooled_prompt_embeds
+
+    # Handle instance prompt.
+    instance_time_ids = compute_time_ids()
+    if not args.train_text_encoder:
+        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+            args.instance_prompt, text_encoders, tokenizers
+        )
+
+    # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings(
-            args.class_prompt, text_encoders, tokenizers
-        )
-
-    del tokenizers, text_encoders
-
-    gc.collect()
-    torch.cuda.empty_cache()
+        class_time_ids = compute_time_ids()
+        if not args.train_text_encoder:
+            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+                args.class_prompt, text_encoders, tokenizers
+            )
+
+    # Clear the memory here.
+    if not args.train_text_encoder:
+        del tokenizers, text_encoders
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    # Pack the statically computed variables appropriately. This is so that we don't
+    # have to pass them to the dataloader.
+    add_time_ids = instance_time_ids
+    if args.with_prior_preservation:
+        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
+
+    if not args.train_text_encoder:
+        prompt_embeds = instance_prompt_hidden_states
+        unet_add_text_embeds = instance_pooled_prompt_embeds
+        if args.with_prior_preservation:
+            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+            unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+    else:
+        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+        if args.with_prior_preservation:
+            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
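The six "time ids" computed by `compute_time_ids` are SDXL's micro-conditioning inputs: original size, crop coordinates, and target size, flattened into a single tensor. A quick sanity check with illustrative values:

```python
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
print(add_time_ids)        # tensor([[1024, 1024, 0, 0, 1024, 1024]])
print(add_time_ids.shape)  # torch.Size([1, 6])
```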
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
-        instance_prompt=args.instance_prompt,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
-        class_prompt=args.class_prompt,
         class_num=args.num_class_images,
         size=args.resolution,
         center_crop=args.center_crop,
-        instance_prompt_hidden_states=instance_prompt_hidden_states,
-        class_prompt_hidden_states=class_prompt_hidden_states,
-        instance_unet_added_conditions=instance_unet_added_conditions,
-        class_unet_added_conditions=class_unet_added_conditions,
     )

     train_dataloader = torch.utils.data.DataLoader(
@@ -954,16 +969,21 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )

     # Prepare everything with our `accelerator`.
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1022,6 +1042,9 @@ def main(args):
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder_one.train()
+            text_encoder_two.train()
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -1030,12 +1053,16 @@ def main(args):
                 continue

             with accelerator.accumulate(unet):
-                # pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                if args.pretrained_vae_model_name_or_path is None:
+                    pixel_values = batch["pixel_values"]
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

                 # Convert images to latent space
-                model_input = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = model_input * vae.config.scaling_factor
-                model_input = model_input.to(weight_dtype)
+                if args.pretrained_vae_model_name_or_path is None:
+                    model_input = model_input.to(weight_dtype)

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
@@ -1051,9 +1078,30 @@ def main(args):
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                 # Predict the noise residual
-                model_pred = unet(
-                    noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"]
-                ).sample
+                if not args.train_text_encoder:
+                    unet_added_conditions = {
+                        "time_ids": add_time_ids.repeat(bsz, 1),
+                        "text_embeds": unet_add_text_embeds.repeat(bsz, 1),
+                    }
+                    model_pred = unet(
+                        noisy_model_input,
+                        timesteps,
+                        prompt_embeds.repeat(bsz, 1, 1),
+                        added_cond_kwargs=unet_added_conditions,
+                    ).sample
+                else:
+                    unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)}
+                    prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                        text_encoders=[text_encoder_one, text_encoder_two],
+                        tokenizers=None,
+                        prompt=None,
+                        text_input_ids_list=[tokens_one, tokens_two],
+                    )
+                    unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)})
+                    prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
+                    model_pred = unet(
+                        noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+                    ).sample
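The `.repeat(...)` calls above tile embeddings that were computed once per prompt so they match the batch of latents. A small shape check (the hidden size 2048 is just illustrative of SDXL's concatenated text-encoder width):

```python
import torch

prompt_embeds = torch.randn(1, 77, 2048)      # one prompt's hidden states
pooled_prompt_embeds = torch.randn(1, 1280)   # one prompt's pooled embedding

bsz = 4  # batch size of the sampled latents
print(prompt_embeds.repeat(bsz, 1, 1).shape)      # torch.Size([4, 77, 2048])
print(pooled_prompt_embeds.repeat(bsz, 1).shape)  # torch.Size([4, 1280])
```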
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
@@ -1081,7 +1129,11 @@ def main(args):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = unet_lora_parameters
+                    params_to_clip = (
+                        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+                        if args.train_text_encoder
+                        else unet_lora_parameters
+                    )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                     optimizer.step()
                     lr_scheduler.step()
@@ -1132,8 +1184,22 @@ def main(args):
                        f" {args.validation_prompt}."
                    )
                    # create pipeline
-                    pipeline = DiffusionPipeline.from_pretrained(
+                    if not args.train_text_encoder:
+                        text_encoder_one = text_encoder_cls_one.from_pretrained(
+                            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+                        )
+                        text_encoder_two = text_encoder_cls_two.from_pretrained(
+                            args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+                        )
+                    pipeline = StableDiffusionXLPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
+                        vae=vae,
+                        text_encoder=accelerator.unwrap_model(text_encoder_one)
+                        if args.train_text_encoder
+                        else text_encoder_one,
+                        text_encoder_2=accelerator.unwrap_model(text_encoder_two)
+                        if args.train_text_encoder
+                        else text_encoder_two,
                        unet=accelerator.unwrap_model(unet),
                        revision=args.revision,
                        torch_dtype=weight_dtype,
@@ -1161,9 +1227,11 @@ def main(args):
                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                    pipeline_args = {"prompt": args.validation_prompt}

-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
-                    ]
+                    with torch.cuda.amp.autocast():
+                        images = [
+                            pipeline(**pipeline_args, generator=generator).images[0]
+                            for _ in range(args.num_validation_images)
+                        ]

                    for tracker in accelerator.trackers:
                        if tracker.name == "tensorboard":
@@ -1189,16 +1257,32 @@ def main(args):
         unet = unet.to(torch.float32)
         unet_lora_layers = unet_attn_processors_state_dict(unet)

-        LoraLoaderMixin.save_lora_weights(
+        if args.train_text_encoder:
+            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+        else:
+            text_encoder_lora_layers = None
+            text_encoder_2_lora_layers = None
+
+        StableDiffusionXLPipeline.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
-            text_encoder_lora_layers=None,
+            text_encoder_lora_layers=text_encoder_lora_layers,
+            text_encoder_2_lora_layers=text_encoder_2_lora_layers,
         )
         # Final inference
         # Load previous pipeline
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+        vae = AutoencoderKL.from_pretrained(
+            vae_path,
+            subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+            revision=args.revision,
+            torch_dtype=weight_dtype,
+        )
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
         )
         # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1250,6 +1334,7 @@ def main(args):
             train_text_encoder=args.train_text_encoder,
             prompt=args.instance_prompt,
             repo_folder=args.output_dir,
+            vae_path=args.pretrained_vae_model_name_or_path,
         )

         upload_folder(
             repo_id=repo_id,
...
@@ -385,6 +385,42 @@ class ExamplesTestsAccelerate(unittest.TestCase):
         starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
         self.assertTrue(starts_with_unet)

+    def test_dreambooth_lora_sdxl_with_text_encoder(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --train_text_encoder
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when training the text encoder, all the parameters in the state dict should start
+            # with `"unet"`, `"text_encoder"`, or `"text_encoder_2"` in their names.
+            keys = lora_state_dict.keys()
+            starts_with_unet = all(
+                k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+            )
+            self.assertTrue(starts_with_unet)
+
     def test_custom_diffusion(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
...
@@ -59,7 +59,7 @@ if is_safetensors_available():
     import safetensors

 if is_transformers_available():
-    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -108,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
 def text_encoder_attn_modules(text_encoder):
     attn_modules = []

-    if isinstance(text_encoder, CLIPTextModel):
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
         for i, layer in enumerate(text_encoder.text_model.encoder.layers):
             name = f"text_model.encoder.layers.{i}.self_attn"
             mod = layer.self_attn
@@ -1016,18 +1016,20 @@ class LoraLoaderMixin:
         warnings.warn(warn_message)

     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
+    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`

         Parameters:
             state_dict (`dict`):
-                A standard state dict containing the lora layer parameters. The key shoult be prefixed with an
+                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alpha (`float`):
                 See `LoRALinearLayer` for more details.
             text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
+            prefix (`str`):
+                Expected prefix of the `text_encoder` in the `state_dict`.
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
@@ -1037,14 +1039,16 @@ class LoraLoaderMixin:
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
-        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+        prefix = cls.text_encoder_name if prefix is None else prefix
+
+        if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
+            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
             text_encoder_lora_state_dict = {
-                k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }

             if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {cls.text_encoder_name}.")
+                logger.info(f"Loading {prefix}.")

                 if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
                     # Convert from the old naming convention to the new naming convention.
@@ -1184,23 +1188,10 @@ class LoraLoaderMixin:
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
         """
-        if os.path.isfile(save_directory):
-            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
-            return
-
-        if save_function is None:
-            if safe_serialization:
-
-                def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
-
-            else:
-                save_function = torch.save
-
-        os.makedirs(save_directory, exist_ok=True)
-
         # Create a flat dictionary.
         state_dict = {}
-        # Populate the dictionary.
+
         if unet_lora_layers is not None:
             weights = (
                 unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
@@ -1222,6 +1213,38 @@ class LoraLoaderMixin:
             state_dict.update(text_encoder_lora_state_dict)

         # Save the model
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def write_lora_layers(
+        state_dict: Dict[str, torch.Tensor],
+        save_directory: str,
+        is_main_process: bool,
+        weight_name: str,
+        save_function: Callable,
+        safe_serialization: bool,
+    ):
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
         if weight_name is None:
             if safe_serialization:
                 weight_name = LORA_WEIGHT_NAME_SAFE
...
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
@@ -841,3 +842,66 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
             return (image,)

         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Override to properly handle the loading and unloading of the additional text encoder.
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alpha=network_alpha,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alpha=network_alpha,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
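The `pack_weights` helper above is what gives the flat state dict its `unet.`, `text_encoder.`, and `text_encoder_2.` prefixes, which `load_lora_weights` later uses to route weights back to the right sub-model. A tiny self-contained illustration of the key naming:

```python
import torch

def pack_weights(layers, prefix):
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

layer = torch.nn.Linear(2, 2)
print(sorted(pack_weights(layer, "text_encoder_2").keys()))
# ['text_encoder_2.bias', 'text_encoder_2.weight']
```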
@@ -21,9 +21,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from huggingface_hub.repocard import RepoCard
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

-from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    EulerDiscreteScheduler,
+    StableDiffusionPipeline,
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+)
 from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
 from diffusers.models.attention_processor import (
     Attention,
@@ -399,7 +406,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
            )
            self.assertIsInstance(module.processor, attn_proc_class)
-    def test_unload_lora(self):
+    def test_unload_lora_sd(self):
         pipeline_components, lora_components = self.get_dummy_components()
         _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
         sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -503,6 +510,175 @@ class LoraLoaderMixinTests(unittest.TestCase):
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+class SDXLLoraLoaderMixinTests(unittest.TestCase):
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            sample_size=128,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+
+        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
+        text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)
+        text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2)
+
+        pipeline_components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+        }
+        lora_components = {
+            "unet_lora_layers": unet_lora_layers,
+            "text_encoder_one_lora_layers": text_encoder_one_lora_layers,
+            "text_encoder_two_lora_layers": text_encoder_two_lora_layers,
+            "unet_lora_attn_procs": unet_lora_attn_procs,
+        }
+        return pipeline_components, lora_components
+
+    def get_dummy_inputs(self, with_generator=True):
+        batch_size = 1
+        sequence_length = 10
+        num_channels = 4
+        sizes = (32, 32)
+
+        generator = torch.manual_seed(0)
+        noise = floats_tensor((batch_size, num_channels) + sizes)
+        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+        pipeline_inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "np",
+        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
+
+        return noise, input_ids, pipeline_inputs
+
+    def test_lora_save_load(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs()
+        original_images = sd_pipe(**pipeline_inputs).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Outputs shouldn't match.
+        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
+
+    def test_unload_lora_sdxl(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+
+        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Unload LoRA parameters.
+        sd_pipe.unload_lora_weights()
+        original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            orig_image_slice, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert not np.allclose(
+            orig_image_slice_two, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert np.allclose(
+            orig_image_slice, orig_image_slice_two, atol=1e-3
+        ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
 @slow
 @require_torch_gpu
 class LoraIntegrationTests(unittest.TestCase):
...