Commit d078e54a authored by Mostofa Patwary

added exit interval for finetuning

parent 825375cf
@@ -16,7 +16,7 @@
 """Finetune utilities."""
 from functools import partial
+import sys
 import torch
 from megatron import get_args
@@ -215,9 +215,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 optimizer, lr_scheduler)
             # Checkpointing
+            saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                saved_checkpoint = True
             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,6 +228,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)
+            # Exiting based on iterations
+            if args.exit_interval and iteration % args.exit_interval == 0:
+                if not saved_checkpoint:
+                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                torch.distributed.barrier()
+                print_rank_0('exiting program at iteration {}'.format(iteration))
+                sys.exit()
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
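
Read outside the diff, the new per-iteration logic amounts to the sketch below. The helper name maybe_exit and its signature are assumptions introduced for illustration; in the commit itself this code is inlined in _train(), with args being the Megatron arguments namespace that carries save, save_interval, and exit_interval.

import sys

import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import save_checkpoint


def maybe_exit(iteration, model, optimizer, lr_scheduler, saved_checkpoint):
    """Illustrative helper mirroring the exit-interval block added here.

    In the actual commit this logic is inlined in _train(); the helper
    name and signature are assumptions made for this sketch.
    """
    args = get_args()
    # Only act on iterations that are multiples of args.exit_interval.
    if args.exit_interval and iteration % args.exit_interval == 0:
        # Skip the save if the save-interval branch already wrote a
        # checkpoint for this iteration (tracked via saved_checkpoint).
        if not saved_checkpoint:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        # Wait for all ranks before any process terminates.
        torch.distributed.barrier()
        print_rank_0('exiting program at iteration {}'.format(iteration))
        sys.exit()

A finetuning run launched with the corresponding --exit-interval N flag will therefore checkpoint (if it has not already done so on that step) and stop cleanly every N iterations, which is useful for fitting long finetuning jobs into fixed-length cluster allocations; when the flag is left unset, the training loop behaves as before.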