Commit 534eaa2c authored by freewym's avatar freewym Committed by Facebook Github Bot

add and set the missing state "shuffle" properly to EpochBatchIterator to correctly recover the training from a "non-shuffle" checkpoint (#1375)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1375

Differential Revision: D18566535

Pulled By: myleott

fbshipit-source-id: ff7b1a6ead708801f537ec7885e30e37168cd34b
parent 4fd2a16b
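The intent of the change is easiest to see as a save/restore round trip: the iterator's state dict now records whether shuffling was active, so a run started without shuffling stays unshuffled after resuming from a checkpoint. Below is a minimal sketch, assuming fairseq is installed; `ToyDataset` and the precomputed `batch_sampler` are hypothetical stand-ins for illustration only and are not part of this commit:

```python
import torch
from fairseq.data import FairseqDataset
from fairseq.data.iterators import EpochBatchIterator


class ToyDataset(FairseqDataset):
    """Hypothetical stand-in: ten scalar items, collated into a 1-D tensor."""

    def __getitem__(self, index):
        return torch.tensor([index])

    def __len__(self):
        return 10

    def collater(self, samples):
        return torch.cat(samples)


dataset = ToyDataset()
batch_sampler = [[i, i + 1] for i in range(0, len(dataset), 2)]  # batches of 2

epoch_itr = EpochBatchIterator(
    dataset=dataset, collate_fn=dataset.collater, batch_sampler=batch_sampler,
)
itr = epoch_itr.next_epoch_itr(shuffle=False)  # train without shuffling
next(itr)                                      # consume one batch

state = epoch_itr.state_dict()  # with this commit, includes 'shuffle': False

# Resuming from the checkpointed state should rebuild a non-shuffled iterator
# and fast-forward past the batch already consumed.
resumed = EpochBatchIterator(
    dataset=dataset, collate_fn=dataset.collater, batch_sampler=batch_sampler,
)
resumed.load_state_dict(state)
```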
@@ -176,6 +176,7 @@ class EpochBatchIterator(EpochBatchIterating):
         self.num_workers = num_workers
         self.epoch = epoch
+        self.shuffle = True
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
         self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
@@ -202,6 +203,7 @@ class EpochBatchIterator(EpochBatchIterating):
             self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
         )
         self.dataset.set_epoch(self.epoch)
+        self.shuffle = shuffle
         return self._cur_epoch_itr

     def end_of_epoch(self) -> bool:
@@ -222,6 +224,7 @@ class EpochBatchIterator(EpochBatchIterating):
         return {
             'epoch': self.epoch,
             'iterations_in_epoch': self.iterations_in_epoch,
+            'shuffle': self.shuffle,
         }

     def load_state_dict(self, state_dict):
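The excerpt stops at the `load_state_dict` signature, so the consuming side of the new key is not shown here. A plausible sketch of how the restored flag could be used when fast-forwarding into a partially completed epoch (a hypothetical reconstruction, assuming the `_get_iterator_for_epoch` helper seen in the hunk above also accepts an `offset` argument):

```python
def load_state_dict(self, state_dict):
    self.epoch = state_dict['epoch']
    itr_pos = state_dict.get('iterations_in_epoch', 0)
    if itr_pos > 0:
        # Rebuild the in-progress epoch iterator with the same shuffle setting
        # that was active when the checkpoint was written; default to True so
        # checkpoints written before this change keep their old behaviour.
        self._next_epoch_itr = self._get_iterator_for_epoch(
            self.epoch,
            shuffle=state_dict.get('shuffle', True),
            offset=itr_pos,
        )
```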