Commit 7aba6084 authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

Update progress_bar to be more robust to changes in tqdm (#21)

parent 2ad58885
...@@ -7,9 +7,11 @@ ...@@ -7,9 +7,11 @@
# #
""" """
Progress bar wrapper around tqdm which handles non-tty outputs Progress bar wrapper around tqdm which handles non-TTY outputs.
""" """
from collections import OrderedDict
from numbers import Number
import sys import sys
from tqdm import tqdm from tqdm import tqdm
...@@ -26,30 +28,51 @@ class progress_bar(tqdm): ...@@ -26,30 +28,51 @@ class progress_bar(tqdm):
return simple_progress_bar(cls.print_interval, *args, **kwargs) return simple_progress_bar(cls.print_interval, *args, **kwargs)
class simple_progress_bar(object):
    """A minimal replacement for tqdm in non-TTY environments.

    Instead of redrawing a bar in place, a plain status line is written to
    stdout every ``print_interval`` items. Only the small part of the tqdm
    interface implemented below is provided.
    """

    def __init__(self, print_interval, iterable, desc, *_args, **_kwargs):
        """Store the iteration settings.

        Args:
            print_interval: a status line is written whenever the item index
                is a non-zero multiple of this value.
            iterable: the wrapped collection; ``__iter__`` calls ``len()``
                on it, so it must be sized.
            desc: prefix placed at the start of every status line.
            *_args, **_kwargs: accepted and ignored.
        """
        super().__init__()
        self.print_interval = print_interval
        self.iterable = iterable
        self.desc = desc
        # __iter__ formats self.postfix into every status line, so it has to
        # exist even when set_postfix() has never been called.
        self.postfix = ''

    def __enter__(self):
        """Return the bar itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc):
        """Do nothing on exit; returning False lets exceptions propagate."""
        return False

    def __iter__(self):
        """Yield each item, writing a status line every ``print_interval`` items."""
        size = len(self.iterable)
        for i, obj in enumerate(self.iterable):
            yield obj
            # The status line is written after the consumer has handled the
            # item; index 0 never produces a line.
            if i > 0 and i % self.print_interval == 0:
                msg = '{}: {:5d} / {:d} {}\n'.format(self.desc, i, size, self.postfix)
                sys.stdout.write(msg)
                sys.stdout.flush()

    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
        """Build the ``key=value`` summary appended to each status line.

        Args:
            ordered_dict: optional mapping (or iterable of pairs) whose
                entries come first, in their own order.
            refresh: accepted but not used by this implementation.
            **kwargs: additional stats, appended in alphabetical key order.
        """
        # Sort in alphabetical order to be more deterministic
        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
        for key in sorted(kwargs):
            postfix[key] = kwargs[key]

        parts = []
        for key, value in postfix.items():
            if isinstance(value, Number):
                # Number: limit the length of the string
                value = '{0:2.3g}'.format(value)
            elif not isinstance(value, str):
                # Any other non-string type: fall back to its str() form
                value = str(value)
            parts.append(key + '=' + value.strip())

        # Stitch together to get the final postfix
        self.postfix = ', '.join(parts)

    @classmethod
    def write(cls, s, file=None, end="\n"):
        """Write ``s`` followed by ``end`` to ``file`` (stdout by default) and flush."""
        fp = file if file is not None else sys.stdout
        fp.write(s)
        fp.write(end)
        fp.flush()

    @staticmethod
    def status_printer(file):
        """Return a status callback that ignores its argument and does nothing."""
        def print_status(s):
            pass
        return print_status
...@@ -135,7 +135,7 @@ def main(): ...@@ -135,7 +135,7 @@ def main():
display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)]) display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
wps_meter.update(src.size(0)) wps_meter.update(src.size(0))
t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg))) t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
num_sentences += 1 num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format( print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
......
...@@ -148,7 +148,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus): ...@@ -148,7 +148,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
('lr', lr), ('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)), ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
('gnorm', '{:.4f}'.format(gnorm_meter.avg)), ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
])) ]), refresh=False)
if i == 0: if i == 0:
# ignore the first mini-batch in words-per-second calculation # ignore the first mini-batch in words-per-second calculation
...@@ -182,7 +182,7 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus): ...@@ -182,7 +182,7 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
ntokens = sum(s['ntokens'] for s in sample) ntokens = sum(s['ntokens'] for s in sample)
loss = trainer.valid_step(sample, criterion) loss = trainer.valid_step(sample, criterion)
loss_meter.update(loss, ntokens) loss_meter.update(loss, ntokens)
t.set_postfix(loss='{:.2f}'.format(loss_meter.avg)) t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
val_loss = loss_meter.avg val_loss = loss_meter.avg
t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}' t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment