Unverified Commit 9c82b68f authored by Humphrey009's avatar Humphrey009 Committed by GitHub
Browse files

fix problem of 'accelerator.is_main_process' to run in mutiple GPUs (#5340)



fix problem of 'accelerator.is_main_process' to run in mutiple GPUs or NPUs
Co-authored-by: default avatarjiaqiw <wangjiaqi50@huawei.com>
parent d3e0750d
...@@ -607,6 +607,7 @@ def main(args): ...@@ -607,6 +607,7 @@ def main(args):
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit` # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None: if args.checkpoints_total_limit is not None:
...@@ -628,7 +629,6 @@ def main(args): ...@@ -628,7 +629,6 @@ def main(args):
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint) shutil.rmtree(removing_checkpoint)
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}") logger.info(f"Saved state to {save_path}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment