Unverified commit c6e08ecd authored by Linoy Tsaban, committed by GitHub

[Sd3 Dreambooth LoRA] Add text encoder training for the clip encoders (#8630)



* add clip text-encoder training

* no dora

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* add text_encoder layers to save_lora

* style

* fix imports

* style

* fix text encoder

* review changes

* review changes

* review changes

* minor change

* add lora tag

* style

* add readme notes

* add tests for clip encoders

* style

* typo

* fixes

* style

* Update tests/lora/test_lora_layers_sd3.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/README_sd3.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* minor readme change

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 4ad7a1f5
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --push_to_hub
```
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
> By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and its weights remain frozen when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="trained-sd3-lora"
accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --output_dir=$OUTPUT_DIR \
  --dataset_name="Norod78/Yarn-art-style" \
  --instance_prompt="a photo of TOK yarn art dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --train_text_encoder \
  --gradient_accumulation_steps=1 \
  --optimizer="prodigy" \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --rank=32 \
  --seed="0" \
  --push_to_hub
```
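Once training completes, the saved LoRA (which includes the CLIP text encoder layers when `--train_text_encoder` was used) can be loaded for inference roughly as in the sketch below; the paths, dtype, and device are illustrative and should be adjusted to your setup:

```python
# Minimal inference sketch: load the LoRA produced by the run above.
# The output directory name follows the example command; adjust it to your setup.
import torch
from diffusers import StableDiffusion3Pipeline

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights picks up the transformer LoRA and, if present,
# the `text_encoder` / `text_encoder_2` LoRA layers saved by the script.
pipeline.load_lora_weights("trained-sd3-lora")

image = pipeline("a photo of TOK yarn art dog").images[0]
image.save("yarn_art_dog.png")
```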
## Other notes
We default to the "logit_normal" weighting scheme for the loss, following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported by the training script, training may incur numerical instabilities.
\ No newline at end of file
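In the training script the scheme is selected via `--weighting_scheme` (with `--logit_mean`, `--logit_std`, and `--mode_scale` as companion knobs; argument names as assumed here). A rough sketch of how the scheme feeds timestep sampling, using the `compute_density_for_timestep_sampling` helper imported in the script diff below:

```python
# Hedged sketch: how a weighting scheme turns into training timestep indices.
# The logit_mean/logit_std/mode_scale values mirror the script's usual defaults.
import torch
from diffusers.training_utils import compute_density_for_timestep_sampling

batch_size = 4
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",
    batch_size=batch_size,
    logit_mean=0.0,
    logit_std=1.0,
    mode_scale=1.29,
)
# `u` lies in [0, 1]; scale it to the scheduler's discrete timestep range.
num_train_timesteps = 1000
timestep_indices = (u * num_train_timesteps).long()
print(timestep_indices)
```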
@@ -54,6 +54,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
@@ -80,6 +81,7 @@ def save_model_card(
    repo_id: str,
    images=None,
    base_model: str = None,
    train_text_encoder=False,
    instance_prompt=None,
    validation_prompt=None,
    repo_folder=None,
@@ -103,6 +105,8 @@ These are {repo_id} DreamBooth weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).

LoRA for the text encoder was enabled: {train_text_encoder}.

## Trigger words

You should use {instance_prompt} to trigger the image generation.
@@ -113,7 +117,7 @@ You should use {instance_prompt} to trigger the image generation.
## License

Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
@@ -128,6 +132,7 @@ Please adhere to the licensing terms as described `[here](https://huggingface.co
        "text-to-image",
        "diffusers-training",
        "diffusers",
        "lora",
        "sd3",
        "sd3-diffusers",
        "template:sd-lora",
@@ -381,6 +386,12 @@ def parse_args(input_args=None):
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder (clip text encoders only). If set, the text encoder should be float32 precision.",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
@@ -856,10 +867,12 @@ def _encode_prompt_with_t5(
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
@@ -869,6 +882,10 @@ def _encode_prompt_with_t5(
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    dtype = text_encoder.dtype
@@ -888,11 +905,13 @@ def _encode_prompt_with_clip(
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
@@ -902,6 +921,10 @@ def _encode_prompt_with_clip(
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]
@@ -923,6 +946,7 @@ def encode_prompt(
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -931,13 +955,14 @@ def encode_prompt(
    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []

    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids_list[i],
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
@@ -951,6 +976,7 @@ def encode_prompt(
        max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[:-1],
        device=device if device is not None else text_encoders[-1].device,
    )
@@ -1145,6 +1171,9 @@ def main(args):
    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()
            text_encoder_two.gradient_checkpointing_enable()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
@@ -1155,6 +1184,16 @@ def main(args):
    )
    transformer.add_adapter(transformer_lora_config)

    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.rank,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        text_encoder_one.add_adapter(text_lora_config)
        text_encoder_two.add_adapter(text_lora_config)

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
@@ -1164,10 +1203,16 @@ def main(args):
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_one_lora_layers_to_save = None
            text_encoder_two_lora_layers_to_save = None

            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1175,17 +1220,26 @@ def main(args):
                weights.pop()

            StableDiffusion3Pipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        transformer_ = None
        text_encoder_one_ = None
        text_encoder_two_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                text_encoder_one_ = model
            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                text_encoder_two_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1204,12 +1258,21 @@ def main(args):
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        if args.train_text_encoder:
            # Do we need to call `scale_lora_layers()` here?
            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
            _set_state_dict_into_text_encoder(
                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
            )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            models = [transformer_]
            if args.train_text_encoder:
                models.extend([text_encoder_one_, text_encoder_two_])
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models)
@@ -1229,13 +1292,36 @@ def main(args):
    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [transformer]
        if args.train_text_encoder:
            models.extend([text_encoder_one, text_encoder_two])
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models, dtype=torch.float32)

    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    if args.train_text_encoder:
        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))

    # Optimization parameters
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    if args.train_text_encoder:
        # different learning rate for text encoder and unet
        text_lora_parameters_one_with_lr = {
            "params": text_lora_parameters_one,
            "weight_decay": args.adam_weight_decay_text_encoder,
            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
        }
        text_lora_parameters_two_with_lr = {
            "params": text_lora_parameters_two,
            "weight_decay": args.adam_weight_decay_text_encoder,
            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
        }
        params_to_optimize = [
            transformer_parameters_with_lr,
            text_lora_parameters_one_with_lr,
            text_lora_parameters_two_with_lr,
        ]
    else:
        params_to_optimize = [transformer_parameters_with_lr]

    # Optimizer creation
@@ -1317,6 +1403,7 @@ def main(args):
        num_workers=args.dataloader_num_workers,
    )

    if not args.train_text_encoder:
        tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
        text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
@@ -1329,19 +1416,20 @@ def main(args):
                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
            return prompt_embeds, pooled_prompt_embeds

    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
            args.instance_prompt, text_encoders, tokenizers
        )

    # Handle class prompt for prior-preservation.
    if args.with_prior_preservation:
        if not args.train_text_encoder:
            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                args.class_prompt, text_encoders, tokenizers
            )

    # Clear the memory here
    if not args.train_text_encoder and train_dataset.custom_instance_prompts:
        del tokenizers, text_encoders
        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
        del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1354,6 +1442,7 @@ def main(args):
    # have to pass them to the dataloader.
    if not train_dataset.custom_instance_prompts:
        if not args.train_text_encoder:
            prompt_embeds = instance_prompt_hidden_states
            pooled_prompt_embeds = instance_pooled_prompt_embeds
            if args.with_prior_preservation:
@@ -1390,6 +1479,19 @@ def main(args):
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        (
            transformer,
            text_encoder_one,
            text_encoder_two,
            optimizer,
            train_dataloader,
            lr_scheduler,
        ) = accelerator.prepare(
            transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
        )
    else:
        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            transformer, optimizer, train_dataloader, lr_scheduler
        )
@@ -1470,6 +1572,13 @@ def main(args):
    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()
        if args.train_text_encoder:
            text_encoder_one.train()
            text_encoder_two.train()
            # set top parameter requires_grad = True for gradient checkpointing works
            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
@@ -1479,7 +1588,30 @@ def main(args):
                # encode batch prompts when custom prompts are provided for each image -
                if train_dataset.custom_instance_prompts:
                    if not args.train_text_encoder:
                        prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
                            prompts, text_encoders, tokenizers
                        )
                    else:
                        tokens_one = tokenize_prompt(tokenizer_one, prompts)
                        tokens_two = tokenize_prompt(tokenizer_two, prompts)
                        tokens_three = tokenize_prompt(tokenizer_three, prompts)
                        prompt_embeds, pooled_prompt_embeds = encode_prompt(
                            text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                            tokenizers=[None, None, tokenizer_three],
                            prompt=prompts,
                            max_sequence_length=args.max_sequence_length,
                            text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                        )
                else:
                    if args.train_text_encoder:
                        prompt_embeds, pooled_prompt_embeds = encode_prompt(
                            text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                            tokenizers=[None, None, tokenizer_three],
                            prompt=prompts,
                            max_sequence_length=args.max_sequence_length,
                            text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                        )

                # Convert images to latent space
                model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1553,7 +1685,11 @@ def main(args):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = itertools.chain(
                        transformer_lora_parameters,
                        text_lora_parameters_one,
                        text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
@@ -1600,10 +1736,18 @@ def main(args):
        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                if not args.train_text_encoder:
                    # create pipeline
                    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                    )
                else:
                    text_encoder_three = text_encoder_cls_three.from_pretrained(
                        args.pretrained_model_name_or_path,
                        subfolder="text_encoder_3",
                        revision=args.revision,
                        variant=args.variant,
                    )
                pipeline = StableDiffusion3Pipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    vae=vae,
@@ -1634,8 +1778,20 @@ def main(args):
        transformer = transformer.to(torch.float32)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        if args.train_text_encoder:
            text_encoder_one = unwrap_model(text_encoder_one)
            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            text_encoder_two = unwrap_model(text_encoder_two)
            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
        else:
            text_encoder_lora_layers = None
            text_encoder_2_lora_layers = None

        StableDiffusion3Pipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
            text_encoder_lora_layers=text_encoder_lora_layers,
            text_encoder_2_lora_layers=text_encoder_2_lora_layers,
        )

        # Final inference
@@ -1669,6 +1825,7 @@ def main(args):
            base_model=args.pretrained_model_name_or_path,
            instance_prompt=args.instance_prompt,
            validation_prompt=args.validation_prompt,
            train_text_encoder=args.train_text_encoder,
            repo_folder=args.output_dir,
        )
        upload_folder(
@@ -1601,6 +1601,8 @@ class SD3LoraLoaderMixin:
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
@@ -1632,12 +1634,20 @@ class SD3LoraLoaderMixin:
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
            )

        if transformer_lora_layers:
            state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))

        if text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
@@ -137,6 +137,15 @@ class SD3LoRATests(unittest.TestCase):
        )
        return lora_config

    def get_lora_config_for_text_encoders(self):
        text_lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        return text_lora_config

    def test_simple_inference_with_transformer_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()
@@ -173,6 +182,55 @@ class SD3LoRATests(unittest.TestCase):
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_clip_encoders_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()
        text_encoder_config = self.get_lora_config_for_text_encoders()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)

        pipe.transformer.add_adapter(transformer_config)
        pipe.text_encoder.add_adapter(text_encoder_config)
        pipe.text_encoder_2.add_adapter(text_encoder_config)

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2.")

        inputs = self.get_dummy_inputs(torch_device)
        images_lora = pipe(**inputs).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
            text_encoder_one_state_dict = get_peft_model_state_dict(pipe.text_encoder)
            text_encoder_two_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname,
                transformer_lora_layers=transformer_state_dict,
                text_encoder_lora_layers=text_encoder_one_state_dict,
                text_encoder_2_lora_layers=text_encoder_two_state_dict,
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

        inputs = self.get_dummy_inputs(torch_device)
        images_lora_from_pretrained = pipe(**inputs).images

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )
    def test_simple_inference_with_transformer_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
@@ -206,6 +264,44 @@ class SD3LoRATests(unittest.TestCase):
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_clip_encoders_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        text_encoder_config = self.get_lora_config_for_text_encoders()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        pipe.text_encoder.add_adapter(text_encoder_config)
        pipe.text_encoder_2.add_adapter(text_encoder_config)

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()