Commit c6d6256b authored by Myle Ott

Add `--log-format` option and JSON logger

parent 50fdf591
@@ -18,6 +18,8 @@ def get_parser(desc):
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N updates (when progress bar is disabled)')
parser.add_argument('--log-format', default='tqdm', help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
return parser
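As an illustrative aside (not part of the commit), the two logging options added above can be mirrored in a standalone parser to see how they parse:

import argparse

# Minimal stand-in for fairseq's option parser, reproducing only the logging flags.
parser = argparse.ArgumentParser()
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                    help='log progress every N updates (when progress bar is disabled)')
parser.add_argument('--log-format', default='tqdm', help='log format to use',
                    choices=['json', 'none', 'simple', 'tqdm'])

args = parser.parse_args(['--log-format', 'json', '--log-interval', '500'])
assert args.log_format == 'json' and args.log_interval == 500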
@@ -7,35 +7,29 @@
#
"""
Progress bar wrapper around tqdm which handles non-TTY outputs.
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
from collections import OrderedDict
import json
from numbers import Number
import sys
from tqdm import tqdm
from fairseq.meters import AverageMeter
class progress_bar(tqdm):
enabled = sys.stderr.isatty()
print_interval = 1000
def __new__(cls, *args, **kwargs):
if cls.enabled:
return tqdm(*args, **kwargs)
else:
return simple_progress_bar(cls.print_interval, *args, **kwargs)
class simple_progress_bar(object):
"""A minimal replacement for tqdm in non-TTY environments."""
def __init__(self, print_interval, iterable, desc=None, *_args, **_kwargs):
super().__init__()
self.print_interval = print_interval
class progress_bar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.desc = desc
self.epoch = epoch
self.prefix = ''
if epoch is not None:
self.prefix += f'| epoch {epoch:03d}'
if prefix is not None:
self.prefix += f' | {prefix}'
def __enter__(self):
return self
@@ -44,36 +38,149 @@ class simple_progress_bar(object):
return False
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
yield obj
if i > 0 and i % self.print_interval == 0:
desc = '' if self.desc is None else '{}: '.format(self.desc)
msg = '{}{:5d} / {:d} {}\n'.format(desc, i, size, self.postfix)
sys.stdout.write(msg)
sys.stdout.flush()
raise NotImplementedError
def log(self, stats):
"""Log intermediate stats according to log_interval."""
raise NotImplementedError
def print(self, stats):
"""Print end-of-epoch stats."""
raise NotImplementedError
def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
# Sort in alphabetical order to be more deterministic
postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
for key in sorted(kwargs.keys()):
postfix[key] = kwargs[key]
def _str_commas(self, stats):
return ', '.join(key + '=' + stats[key].strip()
for key in stats.keys())
def _str_pipes(self, stats):
return ' | '.join(key + ' ' + stats[key].strip()
for key in stats.keys())
def _format_stats(self, stats):
postfix = OrderedDict(stats)
# Preprocess stats according to datatype
for key in postfix.keys():
# Number: limit the length of the string
if isinstance(postfix[key], Number):
postfix[key] = '{0:2.3g}'.format(postfix[key])
postfix[key] = '{:g}'.format(postfix[key])
# Meter: display both current and average value
elif isinstance(postfix[key], AverageMeter):
postfix[key] = '{:.2f} ({:.2f})'.format(
postfix[key].val, postfix[key].avg)
# Else for any other type, try to get the string conversion
elif not isinstance(postfix[key], str):
postfix[key] = str(postfix[key])
# Else if it's a string, don't need to preprocess anything
# Stitch together to get the final postfix
self.postfix = ', '.join(key + '=' + postfix[key].strip()
for key in postfix.keys())
@classmethod
def write(cls, s, file=None, end="\n"):
fp = file if file is not None else sys.stdout
fp.write(s)
fp.write(end)
fp.flush()
return postfix
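As a small illustration (not part of the commit, assuming fairseq is importable), _format_stats turns plain numbers into '{:g}'-formatted strings and AverageMeter values into 'current (average)' strings:

from fairseq.meters import AverageMeter
from fairseq.progress_bar import simple_progress_bar

meter = AverageMeter()
meter.update(0.25)
meter.update(0.75)   # val=0.75, avg=0.50

bar = simple_progress_bar(range(10))
print(bar._format_stats({'loss': meter, 'wps': 1432.0}))
# -> OrderedDict([('loss', '0.75 (0.50)'), ('wps', '1432')])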
class json_progress_bar(progress_bar):
"""Log output in JSON format."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.stats = None
def __iter__(self):
size = float(len(self.iterable))
for i, obj in enumerate(self.iterable):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
update = self.epoch + float(i / size) if self.epoch is not None else None
stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
print("sweep_log: " + json.dumps(stats))
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.stats = stats
def print(self, stats):
"""Print end-of-epoch stats."""
stats = self._format_stats(stats, epoch=self.epoch)
print("sweep_log: " + json.dumps(stats))
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
if epoch is not None:
postfix['epoch'] = epoch
if update is not None:
postfix['update'] = update
# Preprocess stats according to datatype
for key in stats.keys():
# Meter: display both current and average value
if isinstance(stats[key], AverageMeter):
postfix[key] = stats[key].val
postfix[key + '_avg'] = stats[key].avg
else:
postfix[key] = stats[key]
return postfix
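To show what the new JSON logger emits, here is an illustrative sketch (not part of the commit) with made-up values; each rate-limited log and each print ends up as one sweep_log line of JSON:

from fairseq import progress_bar
from fairseq.meters import AverageMeter

loss = AverageMeter()
loss.update(5.2)   # current value 5.2, running average 5.2

bar = progress_bar.json_progress_bar(range(100), epoch=1, log_interval=10)
for _ in bar:
    bar.log({'loss': loss})    # buffered; flushed every log_interval updates
bar.print({'loss': loss})      # end-of-epoch summary
# Intermediate lines look roughly like:
#   sweep_log: {"epoch": 1, "update": 1.1, "loss": 5.2, "loss_avg": 5.2}
# and the final print drops the fractional 'update' field.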
class noop_progress_bar(progress_bar):
"""No logging."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
def __iter__(self):
for obj in self.iterable:
yield obj
def log(self, stats):
"""Log intermediate stats according to log_interval."""
pass
def print(self, stats):
"""Print end-of-epoch stats."""
pass
class simple_progress_bar(progress_bar):
"""A minimal logger for non-TTY environments."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.stats = None
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
postfix = self._str_commas(self.stats)
print(f'{self.prefix}: {i:5d} / {size:d} {postfix}')
sys.stdout.flush()
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.stats = self._format_stats(stats)
def print(self, stats):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
print(f'{self.prefix} | {postfix}')
sys.stdout.flush()
class tqdm_progress_bar(progress_bar):
"""Log to tqdm."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
self.tqdm = tqdm(iterable, self.prefix, leave=False)
def __iter__(self):
return iter(self.tqdm)
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
def print(self, stats):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
self.tqdm.write(f'{self.tqdm.desc} | {postfix}')
@@ -14,7 +14,7 @@ import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location
from fairseq import criterions, data, models, tokenizer
from fairseq import criterions, data, models, progress_bar, tokenizer
def parse_args_and_arch(parser):
@@ -36,6 +36,18 @@ def build_criterion(args, src_dict, dst_dict):
return criterions.CrossEntropyCriterion(args, dst_dict)
def build_progress_bar(args, iterator, epoch=None, prefix=None):
if args.log_format == 'json':
bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'none':
bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
elif args.log_format == 'tqdm':
bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
else:
bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
return bar
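The scripts below drive this helper through a with-block; a rough usage sketch (the args namespace here is hand-built for illustration, not the real parsed arguments):

import argparse
from fairseq import utils

# Hypothetical stand-in for the parsed command-line arguments.
args = argparse.Namespace(log_format='simple', log_interval=100)

with utils.build_progress_bar(args, range(1000), epoch=1, prefix='train') as t:
    for _ in t:
        t.log({'loss': 2.0, 'wps': 1500})   # intermediate stats, rate-limited
    t.print({'loss': 2.0})                  # end-of-epoch stats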
def torch_persistent_save(*args, **kwargs):
for i in range(3):
try:
@@ -11,7 +11,6 @@ import torch
from fairseq import bleu, data, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
@@ -27,10 +26,10 @@ def main():
options.add_generation_args(parser)
args = parser.parse_args()
if args.no_progress_bar:
args.log_format = 'none'
print(args)
if args.no_progress_bar:
progress_bar.enabled = False
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
@@ -74,7 +73,7 @@ def main():
args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
num_sentences = 0
with progress_bar(itr, smoothing=0, leave=False) as t:
with utils.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
gen_timer = StopwatchMeter()
translations = translator.generate_batched_itr(
@@ -119,7 +118,7 @@ def main():
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0))
t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
t.log({'wps': round(wps_meter.avg)})
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
@@ -15,7 +15,6 @@ import math
from fairseq import data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.progress_bar import progress_bar
def main():
@@ -38,8 +37,7 @@ def main():
args = utils.parse_args_and_arch(parser)
if args.no_progress_bar:
progress_bar.enabled = False
progress_bar.print_interval = args.log_interval
args.log_format = 'simple'
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
@@ -124,7 +122,7 @@ def main():
def get_perplexity(loss):
try:
return math.pow(2, loss)
return round(math.pow(2, loss), 2)
except OverflowError:
return float('inf')
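Since the perplexity helper treats the loss as a base-2 log-likelihood, the change simply rounds 2**loss to two decimals; a self-contained copy for illustration (not an import from train.py):

import math

def get_perplexity(loss):
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')

print(get_perplexity(3.0))      # 8.0
print(get_perplexity(10000.0))  # inf (math.pow overflows)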
@@ -149,9 +147,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
clip_meter = AverageMeter() # % of updates clipped
extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d}'.format(epoch)
lr = trainer.get_lr()
with progress_bar(itr, desc, leave=False) as t:
with utils.build_progress_bar(args, itr, epoch) as t:
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
loss_dict = trainer.train_step(sample)
loss = loss_dict['loss']
@@ -168,16 +165,16 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
extra_postfix = []
for k, v in loss_dict.items():
extra_meters[k].update(v)
extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
extra_postfix.append((k, extra_meters[k].avg))
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
('wps', '{:5d}'.format(round(wps_meter.avg))),
('wpb', '{:5d}'.format(round(wpb_meter.avg))),
('bsz', '{:5d}'.format(round(bsz_meter.avg))),
t.log(collections.OrderedDict([
('loss', loss_meter),
('wps', round(wps_meter.avg)),
('wpb', round(wpb_meter.avg)),
('bsz', round(bsz_meter.avg)),
('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
] + extra_postfix), refresh=False)
('clip', '{:.0%}'.format(clip_meter.avg)),
] + extra_postfix))
if i == 0:
# ignore the first mini-batch in words-per-second calculation
@@ -185,17 +182,19 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, i + 1)
fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
loss_meter.avg, get_perplexity(loss_meter.avg))
fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
round(bsz_meter.avg), lr, clip_meter.avg * 100)
fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg)
t.print(collections.OrderedDict([
('train loss', round(loss_meter.avg, 2)),
('train ppl', get_perplexity(loss_meter.avg)),
('s/checkpoint', round(wps_meter.elapsed_time)),
('words/s', round(wps_meter.avg)),
('words/batch', round(wpb_meter.avg)),
('bsz', round(bsz_meter.avg)),
('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
] + [
(k, meter.avg)
for k, meter in extra_meters.items()
)
t.write(fmt)
]))
def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
@@ -232,8 +231,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
loss_meter = AverageMeter()
extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
with progress_bar(itr, desc, leave=False) as t:
prefix = 'valid on \'{}\' subset'.format(subset)
with utils.build_progress_bar(args, itr, epoch, prefix) as t:
for _, sample in data.skip_group_enumerator(t, ngpus):
loss_dict = trainer.valid_step(sample)
loss = loss_dict['loss']
@@ -245,23 +244,22 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
extra_postfix = []
for k, v in loss_dict.items():
extra_meters[k].update(v)
extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
extra_postfix.append((k, extra_meters[k].avg))
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f}'.format(loss_meter.avg)),
] + extra_postfix), refresh=False)
t.log(collections.OrderedDict([
('valid loss', round(loss_meter.avg, 2)),
] + extra_postfix))
val_loss = loss_meter.avg
fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
val_loss, get_perplexity(val_loss))
fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg)
t.print(collections.OrderedDict([
('valid loss', round(loss_meter.avg, 2)),
('valid ppl', get_perplexity(loss_meter.avg)),
] + [
(k, meter.avg)
for k, meter in extra_meters.items()
)
t.write(fmt)
]))
# update and return the learning rate
return val_loss
return loss_meter.avg
if __name__ == '__main__':