Commit 534eaa2c authored by freewym's avatar freewym Committed by Facebook Github Bot
Browse files

add and set the missing state "shuffle" properly to EpochBatchIterato… (#1375)

Summary:
…r to correctly recover the training from a "non-shuffle" checkpoint
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1375

Differential Revision: D18566535

Pulled By: myleott

fbshipit-source-id: ff7b1a6ead708801f537ec7885e30e37168cd34b
parent 4fd2a16b
...@@ -176,6 +176,7 @@ class EpochBatchIterator(EpochBatchIterating): ...@@ -176,6 +176,7 @@ class EpochBatchIterator(EpochBatchIterating):
self.num_workers = num_workers self.num_workers = num_workers
self.epoch = epoch self.epoch = epoch
self.shuffle = True
self._cur_epoch_itr = None self._cur_epoch_itr = None
self._next_epoch_itr = None self._next_epoch_itr = None
self._supports_prefetch = getattr(dataset, 'supports_prefetch', False) self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
...@@ -202,6 +203,7 @@ class EpochBatchIterator(EpochBatchIterating): ...@@ -202,6 +203,7 @@ class EpochBatchIterator(EpochBatchIterating):
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
) )
self.dataset.set_epoch(self.epoch) self.dataset.set_epoch(self.epoch)
self.shuffle = shuffle
return self._cur_epoch_itr return self._cur_epoch_itr
def end_of_epoch(self) -> bool: def end_of_epoch(self) -> bool:
...@@ -222,6 +224,7 @@ class EpochBatchIterator(EpochBatchIterating): ...@@ -222,6 +224,7 @@ class EpochBatchIterator(EpochBatchIterating):
return { return {
'epoch': self.epoch, 'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch, 'iterations_in_epoch': self.iterations_in_epoch,
'shuffle': self.shuffle,
} }
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment