"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "57c965a8f1000b4016ad219e616f509d8af3f5b5"
Unverified Commit 60e1d883 authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Fixup no_trainer save logic (#16968)

* Fixup all examples
parent c79bbc3b
...@@ -393,32 +393,38 @@ def main(): ...@@ -393,32 +393,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -436,7 +442,7 @@ def main(): ...@@ -436,7 +442,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -503,33 +503,39 @@ def main(): ...@@ -503,33 +503,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -547,7 +553,7 @@ def main(): ...@@ -547,7 +553,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -549,33 +549,39 @@ def main(): ...@@ -549,33 +549,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -593,7 +599,7 @@ def main(): ...@@ -593,7 +599,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -506,33 +506,39 @@ def main(): ...@@ -506,33 +506,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -550,7 +556,7 @@ def main(): ...@@ -550,7 +556,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -765,33 +765,39 @@ def main(): ...@@ -765,33 +765,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
......
...@@ -771,33 +771,39 @@ def main(): ...@@ -771,33 +771,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -815,7 +821,7 @@ def main(): ...@@ -815,7 +821,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -501,33 +501,39 @@ def main(): ...@@ -501,33 +501,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
model.train() model.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -545,7 +551,7 @@ def main(): ...@@ -545,7 +551,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -563,11 +563,13 @@ def main(): ...@@ -563,11 +563,13 @@ def main():
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}")
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
for epoch in range(args.num_train_epochs): starting_epoch = 0
for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# compute num of losses # compute num of losses
......
...@@ -569,32 +569,38 @@ def main(): ...@@ -569,32 +569,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -612,7 +618,7 @@ def main(): ...@@ -612,7 +618,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -454,32 +454,38 @@ def main(): ...@@ -454,32 +454,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -497,7 +503,7 @@ def main(): ...@@ -497,7 +503,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -606,32 +606,38 @@ def main(): ...@@ -606,32 +606,38 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -649,7 +655,7 @@ def main(): ...@@ -649,7 +655,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
...@@ -552,33 +552,39 @@ def main(): ...@@ -552,33 +552,39 @@ def main():
# Only show the progress bar once on each machine. # Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0 completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
accelerator.load_state(args.resume_from_checkpoint) accelerator.load_state(args.resume_from_checkpoint)
resume_step = None path = os.path.basename(args.resume_from_checkpoint)
path = args.resume_from_checkpoint
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
if "epoch" in path: # Extract `epoch_{i}` or `step_{i}`
args.num_train_epochs -= int(path.replace("epoch_", "")) training_difference = os.path.splitext(path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
else: else:
resume_step = int(path.replace("step_", "")) resume_step = int(training_difference.replace("step_", ""))
args.num_train_epochs -= resume_step // len(train_dataloader) starting_epoch = resume_step // len(train_dataloader)
resume_step = (args.num_train_epochs * len(train_dataloader)) - resume_step resume_step -= starting_epoch * len(train_dataloader)
for epoch in range(args.num_train_epochs): for epoch in range(starting_epoch, args.num_train_epochs):
model.train() model.train()
if args.with_tracking: if args.with_tracking:
total_loss = 0 total_loss = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step # We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == 0 and step < resume_step: if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue continue
outputs = model(**batch) outputs = model(**batch)
loss = outputs.loss loss = outputs.loss
...@@ -596,7 +602,7 @@ def main(): ...@@ -596,7 +602,7 @@ def main():
if isinstance(checkpointing_steps, int): if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0: if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}" output_dir = f"step_{completed_steps }"
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment