Unverified Commit d56825e4 authored by Sayak Paul, committed by GitHub

fix: how print training resume logs. (#5117)



* fix: how print training resume logs.

* propagate changes to text-to-image scripts.

* propagate changes to instructpix2pix.

* propagate changes to dreambooth

* propagate changes to custom diffusion and instructpix2pix

* propagate changes to kandinsky

* propagate changes to textual inv.

* debug

* fix: checkpointing.

* debug

* debug

* debug

* back to the square

* debug

* debug

* change condition order.

* debug

* debug

* debug

* debug

* revert to original

* clean

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent cd1b8d7c
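The snippet below is a minimal, self-contained sketch of the resume pattern this commit applies across the training scripts, shown outside any particular script for readability. The helper name build_resume_state and the surrounding args, accelerator, and num_update_steps_per_epoch objects are illustrative assumptions, not part of the diff; the key point is that tqdm's initial argument seeds the progress bar at the checkpointed step, so the training loop no longer has to fast-forward through the dataloader with resume_step bookkeeping.

import os

from tqdm.auto import tqdm


def build_resume_state(args, accelerator, num_update_steps_per_epoch):
    """Illustrative helper (assumed name): restore training state and build the progress bar."""
    global_step = 0
    first_epoch = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Pick the most recent "checkpoint-<step>" directory, if any exists.
            dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    # tqdm's `initial` argument starts the bar at the resumed step directly.
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    return global_step, first_epoch, progress_bar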
@@ -1075,30 +1075,30 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.modifier_token is not None:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
......
@@ -1178,30 +1178,30 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
......
@@ -1108,30 +1108,30 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
......
@@ -1048,18 +1048,25 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
@@ -1067,12 +1074,6 @@ def main(args):
             text_encoder_one.train()
             text_encoder_two.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
......
@@ -726,6 +726,9 @@ def main():
     text_encoder_1.requires_grad_(False)
     text_encoder_2.requires_grad_(False)
+    # Set UNet to trainable.
+    unet.train()
     # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(text_encoders, tokenizers, prompt):
         prompt_embeds_list = []
@@ -933,29 +936,28 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # We want to learn the denoising process w.r.t the edited images which
                 # are conditioned on the original image (which was edited) and the edit instruction.
......
@@ -512,6 +512,9 @@ def main():
     vae.requires_grad_(False)
     image_encoder.requires_grad_(False)
+    # Set unet to trainable.
+    unet.train()
     # Create EMA for the unet.
     if args.use_ema:
         ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
@@ -727,27 +730,28 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 images = batch["pixel_values"].to(weight_dtype)
......
@@ -579,29 +579,29 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 images = batch["pixel_values"].to(weight_dtype)
......
@@ -595,30 +595,33 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
     clip_std = clip_std.to(weight_dtype).to(accelerator.device)
     for epoch in range(first_epoch, args.num_train_epochs):
         prior.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(prior):
                 # Convert images to latent space
                 text_input_ids, text_mask, clip_images = (
......
@@ -517,6 +517,9 @@ def main():
     text_encoder.requires_grad_(False)
     image_encoder.requires_grad_(False)
+    # Set prior to trainable.
+    prior.train()
     # Create EMA for the prior.
     if args.use_ema:
         ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
@@ -741,32 +744,31 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
     clip_std = clip_std.to(weight_dtype).to(accelerator.device)
     for epoch in range(first_epoch, args.num_train_epochs):
-        prior.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(prior):
                 # Convert images to latent space
                 text_input_ids, text_mask, clip_images = (
......
@@ -577,9 +577,10 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
-    # Freeze vae and text_encoder
+    # Freeze vae and text_encoder and set unet to trainable
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
+    unet.train()
     # Create EMA for the unet.
     if args.use_ema:
@@ -854,29 +855,29 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
......
@@ -429,7 +429,6 @@ def main():
     # freeze parameters of models to save more memory
     unet.requires_grad_(False)
     vae.requires_grad_(False)
-
     text_encoder.requires_grad_(False)
     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -690,29 +689,29 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
......
@@ -947,18 +947,25 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
@@ -967,12 +974,6 @@ def main(args):
             text_encoder_two.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 if args.pretrained_vae_model_name_or_path is not None:
......
@@ -657,6 +657,8 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
+    # Set unet as trainable.
+    unet.train()
     # For mixed precision training we cast all non-trainable weigths to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -967,29 +969,29 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(unet):
                 # Sample noise that we'll add to the latents
                 model_input = batch["model_input"].to(accelerator.device)
......
@@ -809,18 +809,25 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
         # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+        disable=not accelerator.is_local_main_process,
+    )
     # keep original embeddings as reference
     orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
@@ -828,12 +835,6 @@ def main():
     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
......