"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7bbbfbfd18ed9f5f6ce02bf194382a27150dd4c4"
Unverified Commit e01d6cf2 authored by Pedro Cuenca, committed by GitHub

Dreambooth: save / restore training state (#1668)



* Dreambooth: save / restore training state.

* make style

* Rename vars for clarity.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove unused import
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 244e16a7
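
In short, this change switches Dreambooth checkpointing to Accelerate's training-state mechanism: a resumable state is written every `--checkpointing_steps` optimizer updates, and training can pick up from it via `--resume_from_checkpoint`. A minimal sketch of the save/restore pattern the commit builds on (the model, optimizer and paths here are illustrative placeholders, not the script's actual objects):

    import os

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    # Everything passed to `prepare` (plus anything registered with
    # `register_for_checkpointing`) is serialized together, including RNG state.
    save_path = os.path.join("output", "checkpoint-500")
    accelerator.save_state(save_path)

    # Later (possibly in a new process): restore the same training state
    # before continuing the loop.
    accelerator.load_state(save_path)
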
 import argparse
 import hashlib
-import inspect
 import itertools
 import math
 import os
@@ -150,7 +149,24 @@ def parse_args(input_args=None):
         default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
-    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
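
A standalone re-creation of just these two new options (hypothetical values, help text stripped), to show how they parse; in the real script they live inside `parse_args()` alongside the other flags:

    import argparse

    # Minimal parser with only the two options added by this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)

    args = parser.parse_args(["--checkpointing_steps", "200", "--resume_from_checkpoint", "latest"])
    assert args.checkpointing_steps == 200
    assert args.resume_from_checkpoint == "latest"
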
@@ -579,6 +595,7 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, optimizer, train_dataloader, lr_scheduler
         )
+    accelerator.register_for_checkpointing(lr_scheduler)
 
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
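
`register_for_checkpointing` tells Accelerate to save and load the given object in tandem with `save_state` / `load_state`; any object exposing `state_dict` and `load_state_dict` can be registered the same way. A toy sketch with a made-up stateful object (not something the script defines):

    from accelerate import Accelerator

    class StepCounter:
        # Illustrative custom object implementing the state_dict protocol
        # that register_for_checkpointing expects.
        def __init__(self):
            self.steps = 0

        def state_dict(self):
            return {"steps": self.steps}

        def load_state_dict(self, state):
            self.steps = state["steps"]

    accelerator = Accelerator()
    counter = StepCounter()
    accelerator.register_for_checkpointing(counter)
    # counter.steps is now written by accelerator.save_state(...) and restored
    # by accelerator.load_state(...) along with models, optimizers and RNG state.
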
@@ -616,16 +633,41 @@ def main(args):
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f" Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")
-    global_step = 0
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
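
Because one optimizer update spans `gradient_accumulation_steps` dataloader batches, the restored `global_step` (parsed from the `checkpoint-<n>` folder name) is scaled back up before being split into the epoch to restart at and the in-epoch step to skip to. With made-up numbers:

    # Hypothetical values, purely to illustrate the bookkeeping above.
    gradient_accumulation_steps = 2
    num_update_steps_per_epoch = 500
    global_step = 1250  # parsed from a folder named "checkpoint-1250"

    resume_global_step = global_step * gradient_accumulation_steps   # 2500
    first_epoch = resume_global_step // num_update_steps_per_epoch   # 5
    resume_step = resume_global_step % num_update_steps_per_epoch    # 0

    print(first_epoch, resume_step)  # -> 5 0
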
@@ -689,25 +731,11 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if global_step % args.save_steps == 0:
+                if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
-                        # When 'keep_fp32_wrapper' is `False` (the default), then the models are
-                        # unwrapped and the mixed precision hooks are removed, so training crashes
-                        # when the unwrapped models are used for further training.
-                        # This is only supported in newer versions of `accelerate`.
-                        # TODO(Pedro, Suraj): Remove `accepts_keep_fp32_wrapper` when forcing newer accelerate versions
-                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
-                            inspect.signature(accelerator.unwrap_model).parameters.keys()
-                        )
-                        extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
-                        pipeline = DiffusionPipeline.from_pretrained(
-                            args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet, **extra_args),
-                            text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
-                            revision=args.revision,
-                        )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        pipeline.save_pretrained(save_path)
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
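
With this change the periodic checkpoint is no longer an exported `DiffusionPipeline` but a resumable training state written by `accelerator.save_state` into `checkpoint-<global_step>` folders, which is what the `"latest"` resolution earlier relies on. A small helper-style restatement of that resolution (the function name is ours, not part of the script):

    import os

    def find_latest_checkpoint(output_dir):
        # Checkpoint folders are named "checkpoint-<global_step>", so sorting by
        # the numeric suffix yields the most recent one, mirroring the logic above.
        dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        if not dirs:
            return None
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        return os.path.join(output_dir, dirs[-1])
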