Unverified Commit fbe807bf authored by Suraj Patil, committed by GitHub

[dreambooth] allow fine-tuning text encoder (#883)

* allow fine-tuning text encoder

* fix a few things

* update readme
parent a3efa433
@@ -160,6 +160,39 @@ accelerate launch train_dreambooth.py \
--mixed_precision=fp16
```
### Fine-tune text encoder with the UNet.
The script also allows you to fine-tune the `text_encoder` along with the `unet`. Experimentally, fine-tuning the `text_encoder` has been observed to give much better results, especially on faces.
Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
___Note: Training the text encoder requires more memory; with this option, training will not fit on a 16GB GPU. It needs at least 24GB of VRAM.___
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_text_encoder \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--use_8bit_adam \
--gradient_checkpointing \
--learning_rate=2e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
```
## Inference
Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. sks in the above example) in your prompt.
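For example, a minimal inference sketch could look like the following (the model path, prompt, and output filename are illustrative placeholders; point `model_id` at the `--output_dir` used during training):

```python
from diffusers import StableDiffusionPipeline
import torch

# Placeholder: the directory passed as --output_dir during training.
model_id = "path-to-save-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "a photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```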
...
import argparse
import itertools
import math
import os
from pathlib import Path
@@ -100,6 +101,7 @@ def parse_args():
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
@@ -320,6 +322,15 @@ def main():
        logging_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)
@@ -385,8 +396,14 @@ def main():
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

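    # Freeze the VAE always; freeze the text encoder only when it is not being fine-tuned.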
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
@@ -406,8 +423,11 @@ def main():
    else:
        optimizer_class = torch.optim.AdamW

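    # Optimize the UNet parameters alone, or jointly with the text encoder parameters when --train_text_encoder is set.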
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -467,9 +487,14 @@ def main():
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

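    # Wrap the text encoder with accelerator.prepare only when it is being trained.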
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
@@ -480,8 +505,9 @@ def main():
    # Move the text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -516,9 +542,8 @@ def main():
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
@@ -532,8 +557,7 @@ def main():
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
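                # When --train_text_encoder is set, gradients flow through this forward pass (no torch.no_grad()).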
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -556,7 +580,12 @@ def main():
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
@@ -578,7 +607,9 @@ def main():
    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
        )
        pipeline.save_pretrained(args.output_dir)
...