Commit eef6663c authored by Myle Ott, committed by Facebook Github Bot

Add checkpoint write timer

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/613

Differential Revision: D14712311

Pulled By: myleott

fbshipit-source-id: 3e7646629b539c10b6af89dece2c0c564f31125f
parent e88ad84b
train.py
@@ -11,8 +11,8 @@ Train a new model on one or across multiple GPUs.
 import collections
 import itertools
-import os
 import math
+import os
 import random
 
 import torch
@@ -282,6 +282,10 @@ def get_perplexity(loss):
 def save_checkpoint(args, trainer, epoch_itr, val_loss):
     if args.no_save or not distributed_utils.is_master(args):
         return
+
+    write_timer = StopwatchMeter()
+    write_timer.start()
+
     epoch = epoch_itr.epoch
     end_of_epoch = epoch_itr.end_of_epoch()
     updates = trainer.get_num_updates()
@@ -330,6 +334,11 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
             if os.path.lexists(old_chk):
                 os.remove(old_chk)
+
+    write_timer.stop()
+    print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
+        checkpoints[0], epoch, updates, write_timer.sum))
+
 
 def load_checkpoint(args, trainer, epoch_itr):
     """Load a checkpoint and replay dataloader to match."""
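The change wraps the checkpoint-writing path in fairseq's StopwatchMeter: start() is called before the save, stop() after old checkpoints are pruned, and the accumulated .sum of elapsed seconds is reported in the log line. As a rough sketch of that pattern only (SimpleStopwatch below is a hypothetical stand-in written for illustration, not the actual fairseq.meters.StopwatchMeter implementation), a minimal stopwatch-style meter could look like:

import time


class SimpleStopwatch:
    """Hypothetical minimal stand-in for a stopwatch-style meter: accumulates
    wall-clock seconds across start()/stop() pairs into .sum."""

    def __init__(self):
        self.sum = 0.0          # total elapsed seconds across all start/stop pairs
        self._start_time = None

    def start(self):
        self._start_time = time.time()

    def stop(self):
        if self._start_time is not None:
            self.sum += time.time() - self._start_time
            self._start_time = None


# Usage mirroring the timed section of save_checkpoint above:
write_timer = SimpleStopwatch()
write_timer.start()
# ... write and prune checkpoint files here ...
write_timer.stop()
print('| saved checkpoint (writing took {} seconds)'.format(write_timer.sum))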