Unverified Commit 9ea7052f authored by Suraj Patil, committed by GitHub

[textual inversion] add gradient checkpointing and small fixes. (#1848)


Co-authored-by: Henrik Forstén <henrik.forsten@gmail.com>

* update TI script

* make flake happy

* fix typo
parent 03bf877b
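For context, gradient checkpointing trades compute for memory: intermediate activations are dropped during the forward pass and recomputed during the backward pass. A minimal sketch of the idea with torch.utils.checkpoint (illustration only, not part of this commit; the toy module below is made up):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-in; in the script the checkpointed modules are the CLIP text
# encoder and the UNet, enabled through their own helper methods (see below).
block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed on backward,
# lowering peak memory at the cost of a slower backward pass.
y = checkpoint(block, x)
y.sum().backward()
```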
 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )
 
+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
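As a side note, the two method names above differ because the text encoder comes from transformers while the UNet comes from diffusers. A standalone sketch of the same calls, assuming the classes the script loads and a placeholder checkpoint id:

```python
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "CompVis/stable-diffusion-v1-4"  # placeholder checkpoint for illustration
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# transformers models expose gradient_checkpointing_enable(); diffusers
# models expose enable_gradient_checkpointing().
text_encoder.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()
```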
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
 
     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.scale_lr:
         args.learning_rate = (
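In case the swap above looks like a behavior change: nn.Module.requires_grad_(False) freezes every parameter the module owns, so it is equivalent to the removed freeze_params loop. A quick check on a small stand-in module (the names below are made up for illustration):

```python
import torch

stand_in = torch.nn.Sequential(torch.nn.Embedding(10, 4), torch.nn.Linear(4, 4))

# Module-level setter, as the updated script uses...
stand_in.requires_grad_(False)

# ...matches what the removed helper did parameter by parameter.
assert all(not p.requires_grad for p in stand_in.parameters())
```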
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
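The new comment is worth unpacking: checkpointing implementations typically only take effect while a module reports training=True, which is why the UNet is put in train mode even though its weights stay frozen; with dropout at 0 this changes nothing else. A simplified schematic of that pattern (a sketch, not the actual diffusers block):

```python
import torch
from torch.utils.checkpoint import checkpoint

class BlockWithCheckpointing(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(32, 32)
        self.gradient_checkpointing = True

    def forward(self, x):
        # Checkpointing is skipped entirely in eval mode, so the module has to
        # be in train mode for the memory savings to kick in.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self.inner, x)
        return self.inner(x)

block = BlockWithCheckpointing()
block.train()  # mirrors unet.train() when --gradient_checkpointing is set
out = block(torch.randn(4, 32, requires_grad=True))
```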
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
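For what it's worth, torch.randn_like samples with the same shape, dtype, and device as its argument, so the explicit .to(device).to(dtype) chain on the old line is no longer needed. A small illustration:

```python
import torch

latents = torch.randn(2, 4, 64, 64)

# Inherits shape, dtype, and device from `latents` in one call.
noise = torch.randn_like(latents)
assert noise.shape == latents.shape and noise.dtype == latents.dtype and noise.device == latents.device
```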
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-    accelerator.wait_for_everyone()
     accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:
......