Unverified Commit cef2e40e authored by Timothy Blattner, committed by GitHub

Fix for checkpoint rename race condition (#28364)



* Changed the logic for renaming the staging directory when saving a checkpoint so that it only runs on the main process.
Added an fsync call to flush the rename to disk in case os.rename is not atomic (a minimal sketch of this rename-then-fsync pattern is shown below, after the commit metadata).

* Updated styling using make fixup

* Updated the main-process check to use the Trainer's built-in helpers
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* Fixed incorrect usage of the Trainer's main-process checks
Added a "with open" block to ensure the file is closed properly, as suggested in the PR review
Moved checkpoint rotation (_rotate_checkpoints) into the main-process logic

* Removed the "with open" usage because the built-in open() does not work on a directory; os.open does work for directories.

---------
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
parent fff8ca8e
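
For context before the diff, a minimal sketch of the rename-then-fsync pattern this change introduces. This is an illustrative stand-alone helper, not Trainer code; the name durable_rename and its arguments are assumptions. It mirrors the points above: the built-in open() cannot open a directory, so os.open with O_RDONLY is used to obtain a file descriptor that can be fsynced (on POSIX systems).

import os

def durable_rename(staging_dir: str, final_dir: str) -> None:
    # Illustrative helper (not part of Trainer): move a checkpoint staging
    # directory to its final location and flush the change to disk.
    if os.path.exists(staging_dir):
        os.rename(staging_dir, final_dir)
        # os.rename may not be durable on every filesystem, so fsync the
        # renamed directory through a file descriptor. open() cannot open a
        # directory, but os.open with O_RDONLY can on POSIX systems.
        fd = os.open(final_dir, os.O_RDONLY)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)
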
@@ -2391,17 +2391,23 @@ class Trainer:
         # Place checkpoint in final location after all saving is finished.
         # First wait for everyone to finish writing
         self.args.distributed_state.wait_for_everyone()
-        # Then go through the rewriting process starting on process 0
-        if staging_output_dir != output_dir:
-            with self.args.main_process_first(
-                desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
-            ):
-                if os.path.exists(staging_output_dir):
-                    os.rename(staging_output_dir, output_dir)
-
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+        # Then go through the rewriting process, only renaming and rotating from main process(es)
+        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
+            if staging_output_dir != output_dir:
+                if os.path.exists(staging_output_dir):
+                    os.rename(staging_output_dir, output_dir)
+
+                    # Ensure rename completed in cases where os.rename is not atomic
+                    fd = os.open(output_dir, os.O_RDONLY)
+                    os.fsync(fd)
+                    os.close(fd)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+        self.args.distributed_state.wait_for_everyone()
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
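
A usage-level sketch of the gating pattern in the new code: only the checkpoint writer renames and rotates, and every rank waits at a barrier afterwards. This assumes the accelerate library's PartialState API (the wait_for_everyone() call in the diff comes from the same kind of state object); the flag name save_on_each_node and the placeholder body are illustrative, not Trainer code.

from accelerate import PartialState

state = PartialState()
save_on_each_node = False  # assumed stand-in for args.save_on_each_node

# With a shared filesystem, a single global writer (world process zero) is
# enough; when every node saves to its own disk, one writer per node
# (local process zero) is needed instead.
is_writer = state.is_local_main_process if save_on_each_node else state.is_main_process

if is_writer:
    # rename the staging directory and rotate old checkpoints here
    pass

# Keep all ranks in step until the writer has finished.
state.wait_for_everyone()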