Commit 7ca56cb8 authored by freewym's avatar freewym Committed by Facebook Github Bot
Browse files

add set_epoch() to class ConcatDataset and ConcatSentenceDataset to c… (#1272)

Summary:
…all set_epoch() for each sub dataset
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1272

Differential Revision: D18338300

Pulled By: myleott

fbshipit-source-id: 973d57f52c5cf4ad40122d4a625942281c7983b7
parent f03392d1
......@@ -97,3 +97,9 @@ class ConcatDataset(FairseqDataset):
if getattr(ds, 'supports_prefetch', False):
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, 'set_epoch'):
ds.set_epoch(epoch)
......@@ -48,3 +48,9 @@ class ConcatSentencesDataset(FairseqDataset):
for ds in self.datasets:
if getattr(ds, 'supports_prefetch', False):
ds.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, 'set_epoch'):
ds.set_epoch(epoch)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment