Commit 9c102784 authored by Myle Ott

Add training wall time meter

parent f84e1ed4
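
This change introduces a second timing meter: the existing 'wall' meter (a TimeMeter) keeps tracking total elapsed wall time, while the new 'train_wall' StopwatchMeter accumulates only the time spent inside train_step, with the dummy warm-up batch excluded. As a reading aid, here is a minimal sketch of the start()/stop()/sum/reset() interface the diff relies on; it is inferred from how the meter is used below, not copied from fairseq.meters, so the real implementation may differ.

```python
import time

class StopwatchMeter(object):
    """Minimal sketch: accumulates the total duration of start()/stop() intervals."""

    def __init__(self):
        self.reset()

    def start(self):
        # mark the beginning of a timed interval
        self.start_time = time.time()

    def stop(self):
        # fold the elapsed interval into the running total
        if self.start_time is not None:
            self.sum += time.time() - self.start_time
            self.start_time = None

    def reset(self):
        self.sum = 0
        self.start_time = None
```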
@@ -16,7 +16,7 @@ from itertools import chain
 import torch
 from fairseq import distributed_utils, optim, utils
-from fairseq.meters import AverageMeter, TimeMeter
+from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.optim import lr_scheduler
@@ -54,6 +54,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()  # % of updates clipped
         self.meters['oom'] = AverageMeter()   # out of memory
         self.meters['wall'] = TimeMeter()     # wall time in seconds
+        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
         self._buffered_stats = defaultdict(lambda: [])
         self._flat_grads = None
@@ -109,9 +110,14 @@ class Trainer(object):
             self.meters = extra_state['train_meters']
             del extra_state['train_meters']
+            # reset TimeMeters, since their start times don't make sense anymore
+            for meter in self.meters.values():
+                if isinstance(meter, TimeMeter):
+                    meter.reset()
         return extra_state

-    def train_step(self, sample, update_params=True):
+    def train_step(self, sample, update_params=True, dummy_batch=False):
         """Do forward, backward and parameter update."""
         # Set seed based on args.seed and the update number so that we get
         # reproducible results when resuming from checkpoints
@@ -119,6 +125,9 @@ class Trainer(object):
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)

+        if not dummy_batch:
+            self.meters['train_wall'].start()
+
         # forward and backward pass
         sample = self._prepare_sample(sample)
         loss, sample_size, logging_output, oom_fwd = self._forward(sample)
@@ -132,6 +141,16 @@ class Trainer(object):
         # update parameters
         if update_params:
+            agg_logging_output = self._update_params()
+        else:
+            agg_logging_output = None  # buffering updates
+
+        if not dummy_batch:
+            self.meters['train_wall'].stop()
+
+        return agg_logging_output
+
+    def _update_params(self):
         # gather logging outputs from all replicas
         sample_sizes = self._buffered_stats['sample_sizes']
         logging_outputs = self._buffered_stats['logging_outputs']
@@ -186,8 +205,6 @@ class Trainer(object):
         self.clear_buffered_stats()
         return agg_logging_output
-        else:
-            return None  # buffering updates

     def _forward(self, sample, eval=False):
         loss = None
@@ -320,7 +337,7 @@ class Trainer(object):
     def dummy_train_step(self, dummy_batch):
         """Dummy training step for warming caching allocator."""
-        self.train_step(dummy_batch, update_params=False)
+        self.train_step(dummy_batch, update_params=False, dummy_batch=True)
         self.zero_grad()
         self.clear_buffered_stats()
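
The dummy_batch flag threaded through train_step keeps the allocator warm-up pass out of the new meter: dummy_train_step now calls train_step(..., dummy_batch=True), and the stopwatch is started and stopped only when dummy_batch is False. A hedged sketch of how a driver loop interacts with this; the loop, first_sample, and batches are illustrative placeholders, not taken from fairseq's training script.

```python
# Illustrative only: `trainer`, `first_sample`, and `batches` are placeholders.
trainer.dummy_train_step(first_sample)      # warms the caching allocator; not timed
for sample in batches:
    trainer.train_step(sample)              # timed: 'train_wall' starts and stops here
print(trainer.get_meter('train_wall').sum)  # seconds spent inside real train steps
```

The second hunk below adds the new meter to the training-stats helper.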
@@ -185,6 +185,7 @@ def get_training_stats(trainer):
     if trainer.get_meter('loss_scale') is not None:
         stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
     stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
+    stats['train_wall'] = round(trainer.get_meter('train_wall').sum)
     return stats
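
With both meters reported, 'wall' (TimeMeter.elapsed_time) gives total wall-clock time while 'train_wall' (StopwatchMeter.sum) covers only the time inside train_step, so comparing them indicates how much time goes to everything else (data loading, validation, checkpointing). A small illustrative snippet, assuming a trainer instance is in scope:

```python
# Rough utilization estimate from the two meters; the max() guard only avoids
# dividing by zero very early in training.
wall = trainer.get_meter('wall').elapsed_time
train_wall = trainer.get_meter('train_wall').sum
print('time in train_step: {:.1%} of wall time'.format(train_wall / max(wall, 1e-6)))
```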