Commit 1652c372 authored by Yoach Lacombe

fix save/load accelerator state

parent daca5721
@@ -1298,15 +1298,7 @@ def main():
     if training_args.max_steps < 0:
         # we know exactly the number of steps per epoch, so can skip through the required number of batches
-        resume_step = (cur_step - epochs_trained * steps_per_epoch)
-        # TODO: currently broken
-        if resume_step == round(len(vectorized_datasets["train"])/train_batch_size):
-            resume_step = None
-            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
-            epochs_trained += 1
+        resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
     else:
         # Currently we don't know how many steps we've taken in the current epoch
         # So we just shuffle the dataset one extra time and start from a fresh epoch
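For context, this hunk is what lets a resumed run fast-forward to the right place in the data: `cur_step` counts optimizer steps, so the number of micro-batches to skip inside the current epoch is that offset multiplied by `gradient_accumulation_steps`. Below is a minimal sketch of the resume path, assuming a dataloader already prepared by `accelerate` and a checkpoint written with `accelerator.save_state`; the helper name `resume_from_checkpoint` and its arguments are illustrative, not part of this script.

```python
from accelerate import Accelerator


def resume_from_checkpoint(accelerator: Accelerator, checkpoint_path, train_dataloader,
                           cur_step, steps_per_epoch, gradient_accumulation_steps):
    """Restore training state and skip the batches already consumed this epoch."""
    # Reload model, optimizer, scheduler and RNG states saved by accelerator.save_state.
    accelerator.load_state(checkpoint_path)

    # cur_step counts optimizer steps, while the dataloader yields micro-batches,
    # hence the gradient_accumulation_steps factor when fast-forwarding.
    epochs_trained = cur_step // steps_per_epoch
    resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps

    # skip_first_batches advances the prepared dataloader without running any training steps.
    train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
    return train_dataloader, epochs_trained
```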
@@ -1409,7 +1401,9 @@ def main():
     # save checkpoint and weights after each save_steps and at the end of training
     if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
         intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
-        accelerator.save_state(output_dir=intermediate_dir)
+        # safe_serialization=False to avoid shared tensors saving issue (TODO: it's a temporary fix)
+        # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
+        accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
         accelerator.wait_for_everyone()
         if accelerator.is_main_process:
             rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
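This change trades safetensors serialization for plain `torch.save`: safetensors refuses to write state dicts whose tensors share memory (tied embeddings, for example), which is what made `save_state` fail here. A minimal sketch of the checkpointing step with that workaround, assuming the same `checkpoint-{step}-epoch-{epoch}` directory layout as the hunk; the helper name `save_checkpoint` is illustrative.

```python
import os

from accelerate import Accelerator


def save_checkpoint(accelerator: Accelerator, output_dir, cur_step, epoch):
    """Write a full training checkpoint restorable later with accelerator.load_state."""
    intermediate_dir = os.path.join(output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")

    # safe_serialization=False falls back to torch.save instead of safetensors,
    # which rejects state dicts containing shared tensors
    # (https://github.com/huggingface/transformers/issues/27293).
    accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)

    # Make sure every process has finished writing before the checkpoint is used or pruned.
    accelerator.wait_for_everyone()
    return intermediate_dir
```

Resuming is the mirror image: point `accelerator.load_state` at the same checkpoint directory before re-entering the training loop.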