Unverified Commit d29d97b6 authored by Linoy Tsaban, committed by GitHub

[examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script (#5935)

* imports and readme bug fixes

* bug fix - ensures text_encoder params are dtype==float32 (when using pivotal tuning) even if the rest of the model is loaded in fp16

* added pivotal tuning to readme

* mapping token identifier to new inserted token in validation prompt (if used)

* correct default value of --train_text_encoder_frac

* change default value of --adam_weight_decay_text_encoder

* validation prompt generations when using pivotal tuning bug fix

* style fix

* textual inversion embeddings name change

* style fix

* bug fix - stopping text encoder optimization halfway

* readme - will include token abstraction and new inserted tokens when using pivotal tuning
- added type to --num_new_tokens_per_abstraction

* style fix

---------
Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
parent 7d4a257c
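
The commit bullet about mapping token identifiers to the newly inserted tokens comes down to a plain string replacement applied to the validation prompt (see the corresponding hunk further down). A minimal standalone sketch of that behaviour, using a made-up `token_abstraction_dict` and prompt:

```python
# Sketch only: the mapping and the prompt are illustrative values, not taken from the script.
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}  # one abstraction mapped to two new tokens
validation_prompt = "a photo of TOK riding a bicycle"

# same replacement the training script performs when --train_text_encoder_ti is enabled
for token_abs, token_replacement in token_abstraction_dict.items():
    validation_prompt = validation_prompt.replace(token_abs, "".join(token_replacement))

print(validation_prompt)  # -> "a photo of <s0><s1> riding a bicycle"
```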
@@ -54,7 +54,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
@@ -67,11 +67,46 @@ check_min_version("0.24.0.dev0")

 logger = get_logger(__name__)

+
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
     base_model=str,
     train_text_encoder=False,
+    train_text_encoder_ti=False,
+    token_abstraction_dict=None,
     instance_prompt=str,
     validation_prompt=str,
     repo_folder=None,
@@ -83,10 +118,23 @@ def save_model_card(
         img_str += f"""
         - text: '{validation_prompt if validation_prompt else ' ' }'
           output:
-            url: >-
+            url:
                 "image_{i}.png"
         """
+
+    trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    if train_text_encoder_ti:
+        trigger_str = (
+            "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+            "in you prompt with the new inserted tokens:\n"
+        )
+        if token_abstraction_dict:
+            for key, value in token_abstraction_dict.items():
+                tokens = "".join(value)
+                trigger_str += f"""
+to trigger concept {key}-> use {tokens} in your prompt \n
+"""
+
     yaml = f"""
 ---
 tags:
@@ -96,9 +144,7 @@ tags:
 - diffusers
 - lora
 - template:sd-lora
-widget:
 {img_str}
----
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
@@ -112,14 +158,19 @@ license: openrail++

 ## Model description

-These are {repo_id} LoRA adaption weights for {base_model}.
+### These are {repo_id} LoRA adaption weights for {base_model}.
+
 The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
 LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
 Special VAE used for training: {vae_path}.

 ## Trigger words

-You should use {instance_prompt} to trigger the image generation.
+{trigger_str}

 ## Download model
@@ -244,6 +295,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
+        type=int,
         default=2,
         help="number of new tokens inserted to the tokenizers per token_abstraction value when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
@@ -455,7 +507,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--train_text_encoder_frac",
         type=float,
-        default=0.5,
+        default=1.0,
         help=("The percentage of epochs to perform text encoder tuning"),
     )
@@ -488,7 +540,7 @@ def parse_args(input_args=None):
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+        "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
     )
     parser.add_argument(
@@ -679,12 +731,19 @@ class TokenEmbeddingsHandler:
     def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
+        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+        idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
                 self.tokenizers[0]
             ), "Tokenizers should be the same."
             new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
-            tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+            # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+            # text_encoder 1) to keep compatible with the ecosystem.
+            # Note: When loading with diffusers, any name can work - simply specify in inference
+            tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+            # tensors[f"text_encoders_{idx}"] = new_token_embeddings

         save_file(tensors, file_path)
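
The "clip_l"/"clip_g" keys chosen above match how the saved embeddings are meant to be picked up at inference time. A rough loading sketch, assuming two new tokens `<s0>` and `<s1>` were trained and the file was saved as `embeddings.safetensors` (file name, token names and the LoRA path are illustrative assumptions, not values from this commit):

```python
import torch
from diffusers import StableDiffusionXLPipeline
from safetensors.torch import load_file

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/trained-lora")  # illustrative path to the trained LoRA weights

# one embedding tensor per text encoder, keyed "clip_l" / "clip_g" as in save_embeddings above
state_dict = load_file("embeddings.safetensors")
pipe.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
)
pipe.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2
)
image = pipe("a photo of <s0><s1> riding a bicycle").images[0]
```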
@@ -696,19 +755,6 @@ class TokenEmbeddingsHandler:
     def device(self):
         return self.text_encoders[0].device

-    # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
-    #     # Assuming new tokens are of the format <s_i>
-    #     self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
-    #     special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
-    #     tokenizer.add_special_tokens(special_tokens_dict)
-    #     text_encoder.resize_token_embeddings(len(tokenizer))
-    #
-    #     self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
-    #     assert self.train_ids is not None, "New tokens could not be converted to IDs."
-    #     text_encoder.text_model.embeddings.token_embedding.weight.data[
-    #         self.train_ids
-    #     ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
-
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
@@ -730,15 +776,6 @@ class TokenEmbeddingsHandler:
             new_embeddings = new_embeddings * (off_ratio**0.1)
             text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings

-    # def load_embeddings(self, file_path: str):
-    #     with safe_open(file_path, framework="pt", device=self.device.type) as f:
-    #         for idx in range(len(self.text_encoders)):
-    #             text_encoder = self.text_encoders[idx]
-    #             tokenizer = self.tokenizers[idx]
-    #
-    #             loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
-    #             self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
-

 class DreamBoothDataset(Dataset):
     """
@@ -1216,6 +1253,8 @@ def main(args):
         text_lora_parameters_one = []
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
+                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                param = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1223,6 +1262,8 @@ def main(args):
         text_lora_parameters_two = []
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
+                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                param = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
@@ -1309,12 +1350,16 @@ def main(args):
         # different learning rate for text encoder and unet
         text_lora_parameters_one_with_lr = {
             "params": text_lora_parameters_one,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         text_lora_parameters_two_with_lr = {
             "params": text_lora_parameters_two,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         params_to_optimize = [
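
With `--adam_weight_decay_text_encoder` now defaulting to `None`, the conditionals above make the text-encoder parameter groups inherit the generic `--adam_weight_decay` unless a separate value is passed. A toy illustration of the fallback (values are made up); note that an explicit `0.0` would also fall through, since the check is on truthiness rather than on `None`:

```python
# Stand-ins for args.adam_weight_decay and args.adam_weight_decay_text_encoder (made-up values)
adam_weight_decay = 1e-4
adam_weight_decay_text_encoder = None  # new default: no separate text-encoder value

text_encoder_weight_decay = (
    adam_weight_decay_text_encoder if adam_weight_decay_text_encoder else adam_weight_decay
)
assert text_encoder_weight_decay == 1e-4  # falls back to the unet weight decay
```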
@@ -1494,6 +1539,12 @@ def main(args):
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

+    if args.train_text_encoder_ti and args.validation_prompt:
+        # replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
+        for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
+            args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+        print("validation prompt:", args.validation_prompt)
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1593,27 +1644,10 @@ def main(args):
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                params_to_optimize = params_to_optimize[:1]
-                # reinitializing the optimizer to optimize only on unet params
-                if args.optimizer.lower() == "prodigy":
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        lr=args.learning_rate,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        beta3=args.prodigy_beta3,
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                        decouple=args.prodigy_decouple,
-                        use_bias_correction=args.prodigy_use_bias_correction,
-                        safeguard_warmup=args.prodigy_safeguard_warmup,
-                    )
-                else:  # AdamW or 8-bit-AdamW
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                    )
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             else:
                 # still optimizng the text encoder
                 text_encoder_one.train()
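
The hunk above replaces re-creating the optimizer at the pivot with zeroing the learning rate of the two text-encoder parameter groups (indices 1 and 2 of `params_to_optimize`), which keeps the optimizer state but stops further updates to those parameters. A minimal sketch with toy tensors and plain AdamW showing the effect:

```python
import torch

# Toy parameters standing in for the unet and the two text-encoder parameter groups.
unet_p = torch.nn.Parameter(torch.randn(4))
te1_p = torch.nn.Parameter(torch.randn(4))
te2_p = torch.nn.Parameter(torch.randn(4))

optimizer = torch.optim.AdamW(
    [{"params": [unet_p]}, {"params": [te1_p]}, {"params": [te2_p]}], lr=1e-3
)

# "pivot": keep training the unet, freeze the text encoders by zeroing their group lr
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0

(unet_p.sum() + te1_p.sum() + te2_p.sum()).backward()
unet_before = unet_p.detach().clone()
te1_before = te1_p.detach().clone()
optimizer.step()

assert torch.equal(te1_p.detach(), te1_before)  # no update once the group lr is 0.0
assert not torch.equal(unet_p.detach(), unet_before)  # the unet group keeps training
```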
@@ -1628,7 +1662,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
-                print(prompts)
+                # print(prompts)
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
@@ -1801,7 +1835,7 @@ def main(args):
                    f" {args.validation_prompt}."
                )
                # create pipeline
-                if not args.train_text_encoder:
+                if freeze_text_encoder:
                    text_encoder_one = text_encoder_cls_one.from_pretrained(
                        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
                    )
@@ -1948,6 +1982,8 @@ def main(args):
             images=images,
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
             instance_prompt=args.instance_prompt,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,