"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e547458c43dfdbbb8f6a7757237e234c44e20a8f"
Unverified Commit 4f299b24 authored by Nicholas Broad's avatar Nicholas Broad Committed by GitHub
Browse files

Accelerator end training (#18910)

* add accelerator.end_training()

Some trackers need this to end their runs.

* fixup and quality

* add space

* add space again ?!?
parent 7a811894
...@@ -553,6 +553,9 @@ def main(): ...@@ -553,6 +553,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -648,6 +648,9 @@ def main(): ...@@ -648,6 +648,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -693,6 +693,9 @@ def main(): ...@@ -693,6 +693,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -637,6 +637,9 @@ def main(): ...@@ -637,6 +637,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -662,6 +662,9 @@ def main(): ...@@ -662,6 +662,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -590,6 +590,9 @@ def main(): ...@@ -590,6 +590,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -746,6 +746,9 @@ def main(): ...@@ -746,6 +746,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
...@@ -728,6 +728,9 @@ def main(): ...@@ -728,6 +728,9 @@ def main():
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if args.with_tracking:
accelerator.end_training()
if args.output_dir is not None: if args.output_dir is not None:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model) unwrapped_model = accelerator.unwrap_model(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment