Commit a58c1127 authored by freewym's avatar freewym Committed by Facebook Github Bot
Browse files

fix log printing in progress bar (#778)

Summary:
In the current progress bar, the counter for log_interval will always start from 0, which is not correct if  reloading from a checkpoint in the middle of an epoch. This fix obtains the offset from the iterator to set the counter correctly.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/778

Differential Revision: D15739953

Pulled By: myleott

fbshipit-source-id: a1d13403ec5783b22e01d7cb63874fd8dea7f8b0
parent 1ca075a2
......@@ -208,6 +208,7 @@ class GroupedIterator(object):
def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.offset = int(math.ceil(getattr(iterable, 'count', 0) / float(chunk_size)))
self.itr = iterable
self.chunk_size = chunk_size
......
......@@ -72,6 +72,7 @@ class progress_bar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.offset = getattr(iterable, 'offset', 0)
self.epoch = epoch
self.prefix = ''
if epoch is not None:
......@@ -122,7 +123,7 @@ class json_progress_bar(progress_bar):
def __iter__(self):
size = float(len(self.iterable))
for i, obj in enumerate(self.iterable):
for i, obj in enumerate(self.iterable, start=self.offset):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
......@@ -183,7 +184,7 @@ class simple_progress_bar(progress_bar):
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
for i, obj in enumerate(self.iterable, start=self.offset):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment