Unverified commit 60e1d883 authored by Zachary Mueller, committed by GitHub

Fixup no_trainer save logic (#16968)

* Fixup all examples
parent c79bbc3b
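The same fix is applied file by file in the hunks below. As a quick orientation, here is a minimal, self-contained sketch of the bookkeeping the new code performs: a checkpoint folder named `epoch_{i}` or `step_{i}` is turned into a `starting_epoch` and an in-epoch `resume_step`. The helper name, the example paths, and the `steps_per_epoch` argument are illustrative stand-ins, not part of the commit.

import os

def parse_resume_point(checkpoint_path, steps_per_epoch):
    """Sketch of the resume bookkeeping added in this commit: map a checkpoint
    folder named "epoch_{i}" or "step_{i}" to (starting_epoch, resume_step)."""
    training_difference = os.path.splitext(os.path.basename(checkpoint_path))[0]
    if "epoch" in training_difference:
        # An epoch checkpoint means epoch i finished, so resume at epoch i + 1.
        starting_epoch = int(training_difference.replace("epoch_", "")) + 1
        resume_step = None
    else:
        # A step checkpoint is a global step count, so recover the epoch
        # and the offset inside that epoch.
        resume_step = int(training_difference.replace("step_", ""))
        starting_epoch = resume_step // steps_per_epoch
        resume_step -= starting_epoch * steps_per_epoch
    return starting_epoch, resume_step

# With 200 batches per epoch, "step_500" resumes at epoch 2, step 100 of that epoch:
print(parse_resume_point("out/step_500", steps_per_epoch=200))  # (2, 100)
print(parse_resume_point("out/epoch_3", steps_per_epoch=200))   # (4, None)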
@@ -393,33 +393,39 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
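The skip logic in the loop above is the part that changes most. A stripped-down illustration of the intended behaviour on the resumed epoch (gradient accumulation, the progress bar, and the real dataloader are omitted; `batches` and the argument values below are made up for the example):

def run_epochs(batches, num_train_epochs, starting_epoch, resume_step, resume_from_checkpoint=True):
    """Toy version of the resumed training loop: skip batches already seen before
    the checkpoint was saved, but keep advancing completed_steps while skipping."""
    completed_steps = 0
    trained = []
    for epoch in range(starting_epoch, num_train_epochs):
        for step, batch in enumerate(batches):
            # Only the resumed epoch needs skipping; later epochs run normally.
            if resume_from_checkpoint and epoch == starting_epoch:
                if resume_step is not None and step < resume_step:
                    completed_steps += 1
                    continue
            trained.append((epoch, step))
            completed_steps += 1
    return completed_steps, trained

steps, trained = run_epochs(batches=range(4), num_train_epochs=3, starting_epoch=1, resume_step=2)
print(steps)        # 8: skipped steps still count toward the global step number
print(trained[:2])  # [(1, 2), (1, 3)]: the first two batches of epoch 1 were skipped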
@@ -436,7 +442,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
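For completeness, the folder names parsed by the resume logic come from the two checkpoint flavours these scripts can save. The `epoch_{i}` branch is not visible in this hunk, so treat the sketch below as an assumption about the surrounding script rather than part of the diff; `accelerator` is an `accelerate.Accelerator` and `checkpointing_steps` is either an int or the string "epoch".

import os

def maybe_save_checkpoint(accelerator, args, checkpointing_steps, completed_steps=None, epoch=None):
    """Illustrative helper (assumed shape, not the commit's code): save state under
    step_{N} or epoch_{N} so the resume code above can recover where training stopped."""
    output_dir = None
    if isinstance(checkpointing_steps, int) and completed_steps is not None:
        if completed_steps % checkpointing_steps == 0:
            output_dir = f"step_{completed_steps}"
    elif checkpointing_steps == "epoch" and epoch is not None:
        output_dir = f"epoch_{epoch}"
    if output_dir is not None:
        if args.output_dir is not None:
            output_dir = os.path.join(args.output_dir, output_dir)
        accelerator.save_state(output_dir)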
......
@@ -503,34 +503,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -547,7 +553,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -549,34 +549,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
        else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -593,7 +599,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -506,34 +506,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -550,7 +556,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -765,34 +765,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
......
@@ -771,34 +771,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -815,7 +821,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -501,34 +501,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         if args.with_tracking:
             total_loss = 0
         model.train()
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -545,7 +551,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -563,11 +563,13 @@ def main():
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
     completed_steps = 0
+    starting_epoch = 0
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
-    for epoch in range(args.num_train_epochs):
+    starting_epoch = 0
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         for step, batch in enumerate(train_dataloader):
             # compute num of losses
......
@@ -569,33 +569,39 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -612,7 +618,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -454,33 +454,39 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -497,7 +503,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -606,33 +606,39 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -649,7 +655,7 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......
@@ -552,34 +552,40 @@ def main():
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     completed_steps = 0
+    starting_epoch = 0
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
             accelerator.load_state(args.resume_from_checkpoint)
-            resume_step = None
-            path = args.resume_from_checkpoint
+            path = os.path.basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
             dirs.sort(key=os.path.getctime)
             path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
-        if "epoch" in path:
-            args.num_train_epochs -= int(path.replace("epoch_", ""))
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
         else:
-            resume_step = int(path.replace("step_", ""))
-            args.num_train_epochs -= resume_step // len(train_dataloader)
-            resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step
+            resume_step = int(training_difference.replace("step_", ""))
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
 
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
         for step, batch in enumerate(train_dataloader):
             # We need to skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == 0 and step < resume_step:
-                continue
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    completed_steps += 1
+                    continue
             outputs = model(**batch)
             loss = outputs.loss
             # We keep track of the loss at each epoch
@@ -596,7 +602,7 @@ def main():
            if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    output_dir = f"step_{completed_steps}"
+                    output_dir = f"step_{completed_steps }"
                     if args.output_dir is not None:
                         output_dir = os.path.join(args.output_dir, output_dir)
                     accelerator.save_state(output_dir)
......