Unverified commit 9ea7052f authored by Suraj Patil, committed by GitHub

[textual inversion] add gradient checkpointing and small fixes. (#1848)


Co-authored-by: Henrik Forstén <henrik.forsten@gmail.com>

* update TI script

* make flake happy

* fix typo
parent 03bf877b
 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
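
Note (editorial, not part of the patch): gradient checkpointing trades extra forward compute for lower activation memory by recomputing intermediate activations during the backward pass. A minimal PyTorch sketch of the mechanism this flag opts into; the toy module and shapes are illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy block standing in for a transformer/UNet sub-module (illustrative only).
block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed during backward,
# which is the memory-for-compute trade the --gradient_checkpointing flag enables.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
```

On the command line this is simply an extra `--gradient_checkpointing` switch passed to the training script.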
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"


-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )

+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
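
The two models expose checkpointing through differently named methods: `gradient_checkpointing_enable()` on the transformers text encoder and `enable_gradient_checkpointing()` on the diffusers UNet. A hedged standalone sketch of the same calls; the checkpoint id and subfolders are assumptions, not taken from the diff:

```python
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

# Assumed checkpoint layout (any Stable Diffusion repo with these subfolders).
model_id = "runwayml/stable-diffusion-v1-5"
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# Same two calls the patch adds, guarded by args.gradient_checkpointing above.
text_encoder.gradient_checkpointing_enable()  # transformers API
unet.enable_gradient_checkpointing()          # diffusers API
```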
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

     if args.scale_lr:
         args.learning_rate = (
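
`requires_grad_(False)` works directly on `nn.Module`, which is why the `freeze_params()` helper (and the `itertools.chain` over parameters) could be dropped. A hedged sketch that applies the same freezing and checks that only the token embedding stays trainable; the checkpoint id is illustrative:

```python
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

# Same pattern as the new code: freeze whole submodules in place.
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

trainable = [name for name, p in text_encoder.named_parameters() if p.requires_grad]
print(trainable)  # expected: only text_model.embeddings.token_embedding.weight
```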
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
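
Calling `unet.train()` matters because checkpointing in diffusers blocks is typically gated on the module being in training mode; with dropout at 0 the outputs are unchanged. A hedged toy illustration of that gating pattern (not the diffusers implementation itself):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    # Illustrative gating: checkpoint only when training AND the feature is enabled.
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = True
        self.inner = torch.nn.Linear(32, 32)

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self.inner, x, use_reentrant=False)
        return self.inner(x)

block = Block()
x = torch.randn(4, 32, requires_grad=True)

block.eval()   # checkpointing silently skipped in eval mode
block.train()  # checkpointing active; safe here because there is no dropout
block(x).sum().backward()
```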
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
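
`torch.randn_like` inherits shape, dtype, and device from `latents` in a single call, replacing the old `randn(...).to(...).to(...)` chain, and the timestep sampling is only reflowed for readability. A small self-contained sketch; the shapes and the 1000-step default are assumptions:

```python
import torch

latents = torch.randn(2, 4, 64, 64)  # stand-in for VAE latents

# Matches latents' dtype and device without explicit .to(...) calls.
noise = torch.randn_like(latents)
assert noise.dtype == latents.dtype and noise.device == latents.device

# One integer timestep per sample; .long() because schedulers index with int64.
num_train_timesteps = 1000  # assumed scheduler default
timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=latents.device)
timesteps = timesteps.long()
```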
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break

     accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:
...
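
For context on the truncated `only_save_embeds` branch: in textual inversion only the newly learned token embedding needs to be persisted, not the whole pipeline. A hedged sketch of what that amounts to; the function and variable names are illustrative, not the script's:

```python
import torch

def save_learned_embedding(text_encoder, placeholder_token, placeholder_token_id, save_path):
    # Pull the single learned row out of the input embedding matrix and save it.
    learned_embed = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned_embed.detach().cpu()}, save_path)
```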