Commit d078e54a authored by Mostofa Patwary's avatar Mostofa Patwary

added exit interval for finetuning

With --exit-interval set, the finetuning loop now saves a checkpoint (unless one was just written for --save-interval) and exits cleanly every exit_interval iterations.

parent 825375cf
@@ -16,7 +16,7 @@
 """Finetune utilities."""
 from functools import partial
+import sys
 import torch
 from megatron import get_args
@@ -215,9 +215,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                         optimizer, lr_scheduler)

             # Checkpointing
+            saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                saved_checkpoint = True

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,6 +228,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)

+            # Exiting based on iterations
+            if args.exit_interval and iteration % args.exit_interval == 0:
+                if not saved_checkpoint:
+                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                torch.distributed.barrier()
+                print_rank_0('exiting program at iteration {}'.format(iteration))
+                sys.exit()
+
     # Checkpointing at the end of each epoch.
     if args.save:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
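
The same save-and-exit pattern, reduced to a minimal runnable sketch. This is not Megatron's implementation: it assumes a single process (so the torch.distributed.barrier() that synchronizes ranks before exit in the real code is dropped), and save_checkpoint, train_step, and the interval arguments below are hypothetical stand-ins for Megatron's checkpointing and argument plumbing.

    # Minimal sketch of periodic save-and-exit during finetuning.
    # Assumption: single process; save_checkpoint() and train_step() are
    # hypothetical stand-ins, not Megatron's functions.
    import sys


    def save_checkpoint(iteration):
        # Stand-in for Megatron's save_checkpoint(iteration, model,
        # optimizer, lr_scheduler); here it only reports what it would save.
        print('saving checkpoint at iteration {}'.format(iteration))


    def train_step(iteration):
        # Stand-in for one forward/backward/optimizer step.
        pass


    def train(train_iters=100, save_interval=20, exit_interval=50):
        iteration = 0
        while iteration < train_iters:
            train_step(iteration)
            iteration += 1

            # Periodic checkpointing; remember whether this iteration saved.
            saved_checkpoint = False
            if save_interval and iteration % save_interval == 0:
                save_checkpoint(iteration)
                saved_checkpoint = True

            # Periodic exit: save first unless this iteration already did,
            # so a restarted job resumes exactly where this one stopped.
            if exit_interval and iteration % exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration)
                print('exiting program at iteration {}'.format(iteration))
                sys.exit()


    if __name__ == '__main__':
        train()

The saved_checkpoint flag is the point of the two-line change in the second hunk: when the save interval and exit interval coincide on the same iteration, it avoids writing the same checkpoint twice.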