Unverified commit f4dddaf5 authored by Pedro Cuenca, committed by GitHub
Browse files

[textual_inversion] Fix resuming state when using gradient checkpointing (#2072)

* Fix resuming state when using gradient checkpointing.

Also, allow --resume_from_checkpoint to be used when the checkpoint does
not yet exist (a normal training run will be started).

* style
parent 7d8b4f7f
...@@ -597,7 +597,7 @@ def main(): ...@@ -597,7 +597,7 @@ def main():
text_encoder, optimizer, train_dataloader, lr_scheduler text_encoder, optimizer, train_dataloader, lr_scheduler
) )
# For mixed precision training we cast the text_encoder and vae weights to half-precision # For mixed precision training we cast the unet and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required. # as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32 weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16": if accelerator.mixed_precision == "fp16":
...@@ -643,14 +643,21 @@ def main(): ...@@ -643,14 +643,21 @@ def main():
dirs = os.listdir(args.output_dir) dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] path = dirs[-1] if len(dirs) > 0 else None
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path)) if path is None:
global_step = int(path.split("-")[1]) accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
resume_global_step = global_step * args.gradient_accumulation_steps )
first_epoch = resume_global_step // num_update_steps_per_epoch args.resume_from_checkpoint = None
resume_step = resume_global_step % num_update_steps_per_epoch else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment