"tests/python/vscode:/vscode.git/clone" did not exist on "7c51cd16436c2d774be63c0cec8f222dadf01148"
Commit c6d6256b authored by Myle Ott

Add `--log-format` option and JSON logger

parent 50fdf591
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -18,6 +18,8 @@ def get_parser(desc):
     parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
     parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                         help='log progress every N updates (when progress bar is disabled)')
+    parser.add_argument('--log-format', default='tqdm', help='log format to use',
+                        choices=['json', 'none', 'simple', 'tqdm'])
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
     return parser
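For illustration only, a minimal standalone sketch of the new flag (everything except the --log-format definition above is invented for this example; `choices` gives validation for free):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--log-format', default='tqdm', help='log format to use',
                    choices=['json', 'none', 'simple', 'tqdm'])

args = parser.parse_args(['--log-format', 'json'])
assert args.log_format == 'json'
# An unknown value such as '--log-format fancy' makes argparse exit with
# "invalid choice: 'fancy'", so typos fail fast at startup.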
--- a/fairseq/progress_bar.py
+++ b/fairseq/progress_bar.py
@@ -7,35 +7,29 @@
 #
 """
-Progress bar wrapper around tqdm which handles non-TTY outputs.
+Wrapper around various loggers and progress bars (e.g., tqdm).
 """
 
 from collections import OrderedDict
+import json
 from numbers import Number
 import sys
 
 from tqdm import tqdm
 
+from fairseq.meters import AverageMeter
 
-class progress_bar(tqdm):
-    enabled = sys.stderr.isatty()
-    print_interval = 1000
 
-    def __new__(cls, *args, **kwargs):
-        if cls.enabled:
-            return tqdm(*args, **kwargs)
-        else:
-            return simple_progress_bar(cls.print_interval, *args, **kwargs)
-
-
-class simple_progress_bar(object):
-    """A minimal replacement for tqdm in non-TTY environments."""
-
-    def __init__(self, print_interval, iterable, desc=None, *_args, **_kwargs):
-        super().__init__()
-        self.print_interval = print_interval
+class progress_bar(object):
+    """Abstract class for progress bars."""
+    def __init__(self, iterable, epoch=None, prefix=None):
         self.iterable = iterable
-        self.desc = desc
+        self.epoch = epoch
+        self.prefix = ''
+        if epoch is not None:
+            self.prefix += f'| epoch {epoch:03d}'
+        if prefix is not None:
+            self.prefix += f' | {prefix}'
 
     def __enter__(self):
         return self
@@ -44,36 +38,149 @@ class simple_progress_bar(object):
         return False
 
     def __iter__(self):
-        size = len(self.iterable)
-        for i, obj in enumerate(self.iterable):
-            yield obj
-            if i > 0 and i % self.print_interval == 0:
-                desc = '' if self.desc is None else '{}: '.format(self.desc)
-                msg = '{}{:5d} / {:d} {}\n'.format(desc, i, size, self.postfix)
-                sys.stdout.write(msg)
-                sys.stdout.flush()
+        raise NotImplementedError
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        raise NotImplementedError
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        raise NotImplementedError
+
+    def _str_commas(self, stats):
+        return ', '.join(key + '=' + stats[key].strip()
+                         for key in stats.keys())
+
+    def _str_pipes(self, stats):
+        return ' | '.join(key + ' ' + stats[key].strip()
+                          for key in stats.keys())
 
-    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
-        # Sort in alphabetical order to be more deterministic
-        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
-        for key in sorted(kwargs.keys()):
-            postfix[key] = kwargs[key]
+    def _format_stats(self, stats):
+        postfix = OrderedDict(stats)
         # Preprocess stats according to datatype
         for key in postfix.keys():
             # Number: limit the length of the string
             if isinstance(postfix[key], Number):
-                postfix[key] = '{0:2.3g}'.format(postfix[key])
+                postfix[key] = '{:g}'.format(postfix[key])
+            # Meter: display both current and average value
+            elif isinstance(postfix[key], AverageMeter):
+                postfix[key] = '{:.2f} ({:.2f})'.format(
+                    postfix[key].val, postfix[key].avg)
             # Else for any other type, try to get the string conversion
             elif not isinstance(postfix[key], str):
                 postfix[key] = str(postfix[key])
             # Else if it's a string, don't need to preprocess anything
-        # Stitch together to get the final postfix
-        self.postfix = ', '.join(key + '=' + postfix[key].strip()
-                                 for key in postfix.keys())
+        return postfix
 
-    @classmethod
-    def write(cls, s, file=None, end="\n"):
-        fp = file if file is not None else sys.stdout
-        fp.write(s)
-        fp.write(end)
-        fp.flush()
+
+class json_progress_bar(progress_bar):
+    """Log output in JSON format."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = float(len(self.iterable))
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                update = self.epoch + float(i / size) if self.epoch is not None else None
+                stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
+                print("sweep_log: " + json.dumps(stats))
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = stats
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        stats = self._format_stats(stats, epoch=self.epoch)
+        print("sweep_log: " + json.dumps(stats))
+
+    def _format_stats(self, stats, epoch=None, update=None):
+        postfix = OrderedDict()
+        if epoch is not None:
+            postfix['epoch'] = epoch
+        if update is not None:
+            postfix['update'] = update
+        # Preprocess stats according to datatype
+        for key in stats.keys():
+            # Meter: display both current and average value
+            if isinstance(stats[key], AverageMeter):
+                postfix[key] = stats[key].val
+                postfix[key + '_avg'] = stats[key].avg
+            else:
+                postfix[key] = stats[key]
+        return postfix
+
+
+class noop_progress_bar(progress_bar):
+    """No logging."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+
+    def __iter__(self):
+        for obj in self.iterable:
+            yield obj
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        pass
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        pass
+
+
+class simple_progress_bar(progress_bar):
+    """A minimal logger for non-TTY environments."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = len(self.iterable)
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                postfix = self._str_commas(self.stats)
+                print(f'{self.prefix}: {i:5d} / {size:d} {postfix}')
+                sys.stdout.flush()
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = self._format_stats(stats)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        print(f'{self.prefix} | {postfix}')
+        sys.stdout.flush()
+
+
+class tqdm_progress_bar(progress_bar):
+    """Log to tqdm."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+        self.tqdm = tqdm(iterable, self.prefix, leave=False)
+
+    def __iter__(self):
+        return iter(self.tqdm)
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        self.tqdm.write(f'{self.tqdm.desc} | {postfix}')
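A hedged usage sketch of the new interface follows: the loop and stats values are invented, but the class names, log(), and print() come from the diff above. Every subclass wraps an iterable, buffers intermediate stats via log(), and emits end-of-epoch stats via print():

from collections import OrderedDict
from fairseq import progress_bar

# Hypothetical training loop; json_progress_bar emits a "sweep_log: {...}"
# line every log_interval iterations and one final record from print().
with progress_bar.json_progress_bar(range(5000), epoch=1, log_interval=1000) as t:
    running_loss = 0.0
    for i, _ in enumerate(t):
        running_loss += 0.5  # stand-in for a real loss value
        t.log(OrderedDict(loss=running_loss / (i + 1)))
    t.print(OrderedDict(loss=running_loss / 5000))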
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location
 
-from fairseq import criterions, data, models, tokenizer
+from fairseq import criterions, data, models, progress_bar, tokenizer
 
 
 def parse_args_and_arch(parser):
@@ -36,6 +36,18 @@ def build_criterion(args, src_dict, dst_dict):
     return criterions.CrossEntropyCriterion(args, dst_dict)
 
 
+def build_progress_bar(args, iterator, epoch=None, prefix=None):
+    if args.log_format == 'json':
+        bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
+    elif args.log_format == 'none':
+        bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
+    elif args.log_format == 'tqdm':
+        bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
+    else:
+        bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
+    return bar
+
+
 def torch_persistent_save(*args, **kwargs):
     for i in range(3):
         try:
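To make the dispatch concrete, a small sketch (the SimpleNamespace stands in for the parsed fairseq args; only log_format and log_interval are read on this code path):

from types import SimpleNamespace
from fairseq import utils

# Hypothetical args object for illustration.
args = SimpleNamespace(log_format='simple', log_interval=500)
bar = utils.build_progress_bar(args, range(100), epoch=2, prefix='demo')
assert type(bar).__name__ == 'simple_progress_bar'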
--- a/generate.py
+++ b/generate.py
@@ -11,7 +11,6 @@ import torch
 
 from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator
@@ -27,10 +26,10 @@ def main():
     options.add_generation_args(parser)
     args = parser.parse_args()
+    if args.no_progress_bar:
+        args.log_format = 'none'
     print(args)
 
-    if args.no_progress_bar:
-        progress_bar.enabled = False
     use_cuda = torch.cuda.is_available() and not args.cpu
 
     # Load dataset
@@ -74,7 +73,7 @@ def main():
         args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
-    with progress_bar(itr, smoothing=0, leave=False) as t:
+    with utils.build_progress_bar(args, itr) as t:
         wps_meter = TimeMeter()
         gen_timer = StopwatchMeter()
         translations = translator.generate_batched_itr(
@@ -119,7 +118,7 @@ def main():
                 scorer.add(target_tokens, hypo_tokens)
 
             wps_meter.update(src_tokens.size(0))
-            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
+            t.log({'wps': round(wps_meter.avg)})
             num_sentences += 1
 
     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
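The generation loop now funnels its words-per-second estimate through the unified log() call. A self-contained sketch of the same pattern with a stubbed meter (fairseq's TimeMeter is approximated here; only update() and avg mirror its use in the diff):

import time

class TimeMeterStub:
    """Rough stand-in for fairseq.meters.TimeMeter: average rate since start."""
    def __init__(self):
        self.start = time.time()
        self.n = 0
    def update(self, val):
        self.n += val
    @property
    def avg(self):
        return self.n / max(time.time() - self.start, 1e-9)

wps_meter = TimeMeterStub()
wps_meter.update(3200)                 # tokens produced by one batch (invented)
stats = {'wps': round(wps_meter.avg)}  # same shape generate.py passes to t.log()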
--- a/train.py
+++ b/train.py
@@ -15,7 +15,6 @@ import math
 from fairseq import data, options, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.multiprocessing_trainer import MultiprocessingTrainer
-from fairseq.progress_bar import progress_bar
 
 
 def main():
@@ -38,8 +37,7 @@ def main():
     args = utils.parse_args_and_arch(parser)
 
     if args.no_progress_bar:
-        progress_bar.enabled = False
-        progress_bar.print_interval = args.log_interval
+        args.log_format = 'simple'
 
     if not os.path.exists(args.save_dir):
         os.makedirs(args.save_dir)
@@ -124,7 +122,7 @@ def main():
 
 def get_perplexity(loss):
     try:
-        return math.pow(2, loss)
+        return round(math.pow(2, loss), 2)
     except OverflowError:
         return float('inf')
@@ -149,9 +147,8 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
     clip_meter = AverageMeter()  # % of updates clipped
     extra_meters = collections.defaultdict(lambda: AverageMeter())
 
-    desc = '| epoch {:03d}'.format(epoch)
     lr = trainer.get_lr()
-    with progress_bar(itr, desc, leave=False) as t:
+    with utils.build_progress_bar(args, itr, epoch) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
             loss_dict = trainer.train_step(sample)
             loss = loss_dict['loss']
@@ -168,16 +165,16 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
             extra_postfix = []
             for k, v in loss_dict.items():
                 extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+                extra_postfix.append((k, extra_meters[k].avg))
 
-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
-                ('wps', '{:5d}'.format(round(wps_meter.avg))),
-                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
-                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
+            t.log(collections.OrderedDict([
+                ('loss', loss_meter),
+                ('wps', round(wps_meter.avg)),
+                ('wpb', round(wpb_meter.avg)),
+                ('bsz', round(bsz_meter.avg)),
                 ('lr', lr),
-                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-            ] + extra_postfix), refresh=False)
+                ('clip', '{:.0%}'.format(clip_meter.avg)),
+            ] + extra_postfix))
 
             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation
@@ -185,17 +182,19 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                 save_checkpoint(trainer, args, epoch, i + 1)
 
-        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
-            loss_meter.avg, get_perplexity(loss_meter.avg))
-        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
-            round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
-        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
-            round(bsz_meter.avg), lr, clip_meter.avg * 100)
-        fmt += ''.join(
-            ' | {} {:.4f}'.format(k, meter.avg)
-            for k, meter in extra_meters.items()
-        )
-        t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('train loss', round(loss_meter.avg, 2)),
+            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('s/checkpoint', round(wps_meter.elapsed_time)),
+            ('words/s', round(wps_meter.avg)),
+            ('words/batch', round(wpb_meter.avg)),
+            ('bsz', round(bsz_meter.avg)),
+            ('lr', lr),
+            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
+        ] + [
+            (k, meter.avg)
+            for k, meter in extra_meters.items()
+        ]))
 
 
 def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
@@ -232,8 +231,8 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
     loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
 
-    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
-    with progress_bar(itr, desc, leave=False) as t:
+    prefix = 'valid on \'{}\' subset'.format(subset)
+    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
             loss_dict = trainer.valid_step(sample)
             loss = loss_dict['loss']
@@ -245,23 +244,22 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
             extra_postfix = []
             for k, v in loss_dict.items():
                 extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+                extra_postfix.append((k, extra_meters[k].avg))
 
-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f}'.format(loss_meter.avg)),
-            ] + extra_postfix), refresh=False)
+            t.log(collections.OrderedDict([
+                ('valid loss', round(loss_meter.avg, 2)),
+            ] + extra_postfix))
 
-        val_loss = loss_meter.avg
-        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
-            val_loss, get_perplexity(val_loss))
-        fmt += ''.join(
-            ' | {} {:.4f}'.format(k, meter.avg)
-            for k, meter in extra_meters.items()
-        )
-        t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('valid loss', round(loss_meter.avg, 2)),
+            ('valid ppl', get_perplexity(loss_meter.avg)),
+        ] + [
+            (k, meter.avg)
+            for k, meter in extra_meters.items()
+        ]))
 
     # update and return the learning rate
-    return val_loss
+    return loss_meter.avg
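Because json_progress_bar prefixes every record with "sweep_log: ", downstream tooling can recover training curves with a plain text filter; a hedged sketch (the log file name is invented):

import json

# Hypothetical post-processing of a captured training log.
records = []
with open('train.log') as f:
    for line in f:
        if 'sweep_log: ' in line:
            records.append(json.loads(line.split('sweep_log: ', 1)[1]))

for r in records:
    # Keys such as 'epoch', 'update', and 'train loss' mirror the stats
    # dictionaries passed to log()/print() above.
    print(r.get('epoch'), r.get('update'), r.get('train loss'))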