Unverified Commit 62608a91 authored by Suraj Patil, committed by GitHub

[train_text_to_image] allow using non-ema weights for training (#1834)

* allow using non-ema weights for training

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* address more review comments

* reorganise a few lines

* always pad text to max_length to match original training

* fix collate_fn

* remove unused code

* don't prepare ema_unet, don't register lr scheduler

* style

* assert => ValueError

* add allow_tf32

* set log level

* fix comment

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent e4fe9413
 import argparse
+import copy
 import logging
 import math
 import os
@@ -11,6 +12,9 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint

+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -28,7 +32,7 @@ from transformers import CLIPTextModel, CLIPTokenizer

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level="INFO")


 def parse_args():
@@ -171,7 +175,25 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -247,6 +269,10 @@ def parse_args():
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")

+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
     return args
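A minimal sketch of the fallback above, using a bare argparse.Namespace in place of the parsed arguments; resolve_non_ema_revision is a hypothetical helper name, and the "fp16" and "non-ema" revision values are purely illustrative:

from argparse import Namespace


def resolve_non_ema_revision(args):
    # Mirrors the hunk above: without --non_ema_revision, the trainable unet
    # falls back to the same revision the EMA weights are loaded from.
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision
    return args


args = resolve_non_ema_revision(Namespace(revision="fp16", non_ema_revision=None))
assert args.non_ema_revision == "fp16"  # no flag: both loads use --revision

args = resolve_non_ema_revision(Namespace(revision=None, non_ema_revision="non-ema"))
assert args.non_ema_revision == "non-ema"  # flag set: trainable unet uses that branch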
@@ -275,6 +301,8 @@ class EMAModel:
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]

+        self.collected_params = None
+
         self.decay = decay
         self.optimization_step = 0
@@ -322,6 +350,55 @@ class EMAModel:
             for p in self.shadow_params
         ]

+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to load the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
+

 def main():
     args = parse_args()
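The two new methods give EMAModel the state_dict/load_state_dict interface that accelerate expects from objects passed to register_for_checkpointing. A minimal round-trip sketch, using a small Linear layer as a stand-in for the UNet:

import torch

net = torch.nn.Linear(4, 4)  # stand-in for the UNet
ema = EMAModel(net.parameters())

# A checkpoint save/load reduces to this round trip for the EMA object.
state = ema.state_dict()
restored = EMAModel(torch.nn.Linear(4, 4).parameters())
restored.load_state_dict(state)

assert restored.optimization_step == ema.optimization_step
assert all(torch.equal(a, b) for a, b in zip(restored.shadow_params, ema.shadow_params))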
@@ -339,6 +416,15 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()

     # If passed along, set the training seed now.
     if args.seed is not None:
@@ -361,39 +447,44 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

-    # Load models and create wrapper for stable diffusion
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )

+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        )
+        ema_unet = EMAModel(ema_unet.parameters())
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

-    # Freeze vae and text_encoder
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
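The net effect of this hunk: the trainable unet comes from --non_ema_revision, the EMA copy is still initialised from --revision, and TF32 matmuls can be enabled on Ampere GPUs. A condensed sketch of the same flow; the checkpoint id and the "non-ema" branch name are illustrative assumptions, not something this diff guarantees:

import torch
from diffusers import UNet2DConditionModel

model_id = "CompVis/stable-diffusion-v1-4"  # illustrative checkpoint

# Trainable weights from the (assumed) non-ema branch...
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="non-ema")

# ...while the EMA shadow parameters start from the default (EMA) weights.
ema_unet = EMAModel(
    UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision=None).parameters()
)

# Optional: trade a little matmul precision for speed on Ampere GPUs.
torch.backends.cuda.matmul.allow_tf32 = True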
@@ -419,7 +510,6 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -482,9 +572,10 @@ def main():
                 raise ValueError(
                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
                 )
-        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
-        input_ids = inputs.input_ids
-        return input_ids
+        inputs = tokenizer(
+            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        return inputs.input_ids

     train_transforms = transforms.Compose(
         [
@@ -500,7 +591,6 @@ def main():
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
         examples["input_ids"] = tokenize_captions(examples)
-
         return examples

     with accelerator.main_process_first():
@@ -512,13 +602,8 @@ def main():
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-        input_ids = [example["input_ids"] for example in examples]
-        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
-        return {
-            "pixel_values": pixel_values,
-            "input_ids": padded_tokens.input_ids,
-            "attention_mask": padded_tokens.attention_mask,
-        }
+        input_ids = torch.stack([example["input_ids"] for example in examples])
+        return {"pixel_values": pixel_values, "input_ids": input_ids}

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
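Because every caption is now padded to tokenizer.model_max_length and returned as a tensor at preprocessing time, collate_fn no longer needs tokenizer.pad or an attention_mask; a plain torch.stack produces the batch. A small sketch, assuming the CLIP tokenizer shipped with the v1-4 checkpoint:

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

captions = ["a photo of a cat", "an astronaut riding a horse"]
input_ids = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
).input_ids

# Every row already has the full context length, so stacking rows is enough.
batch = torch.stack([input_ids[0], input_ids[1]])
print(batch.shape)  # torch.Size([2, 77]); CLIP's model_max_length is 77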
@@ -541,23 +626,22 @@ def main():
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)

+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_unet)
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move text_encoder and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Create EMA for the unet.
     if args.use_ema:
-        ema_unet = EMAModel(unet.parameters())
+        ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
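Because ema_unet is registered for checkpointing (and deliberately kept out of accelerator.prepare), its shadow weights are saved and restored together with the prepared unet, optimizer and dataloader; the lr_scheduler no longer needs an explicit register_for_checkpointing call since prepare already tracks it. A hedged sketch of the resulting save/resume flow inside this script, with an illustrative checkpoint path:

import os

ckpt_dir = os.path.join(args.output_dir, "checkpoint-1000")  # illustrative path

# Saves the prepared objects plus everything registered via register_for_checkpointing,
# which now includes the EMA state_dict defined above.
accelerator.save_state(ckpt_dir)

# On resume: rebuild unet/optimizer/dataloader/ema_unet, call prepare() and
# register_for_checkpointing() again, then restore everything in one call.
accelerator.load_state(ckpt_dir)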