Commit 3b09b98b authored by Naman Goyal, committed by Facebook Github Bot
Browse files

fixed train valid epoch iter

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/866

Differential Revision: D17517115

fbshipit-source-id: fd6921e642c99e37fce6ad58b24c93e70a5364e5
parent 10f9349e
...@@ -24,7 +24,7 @@ class FairseqTask(object): ...@@ -24,7 +24,7 @@ class FairseqTask(object):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.datasets = {} self.datasets = {}
self.epoch_iter = None self.dataset_to_epoch_iter = {}
@classmethod @classmethod
def load_dictionary(cls, filename): def load_dictionary(cls, filename):
...@@ -120,7 +120,6 @@ class FairseqTask(object): ...@@ -120,7 +120,6 @@ class FairseqTask(object):
(default: 0). (default: 0).
epoch (int, optional): the epoch to start the iterator from epoch (int, optional): the epoch to start the iterator from
(default: 0). (default: 0).
Returns: Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split given dataset split
...@@ -128,8 +127,8 @@ class FairseqTask(object): ...@@ -128,8 +127,8 @@ class FairseqTask(object):
# For default fairseq task, return same iterator across epochs # For default fairseq task, return same iterator across epochs
# as datasets are not dynamic, can be overridden in task specific # as datasets are not dynamic, can be overridden in task specific
# setting. # setting.
if self.epoch_iter is not None: if dataset in self.dataset_to_epoch_iter:
return self.epoch_iter return self.dataset_to_epoch_iter[dataset]
assert isinstance(dataset, FairseqDataset) assert isinstance(dataset, FairseqDataset)
...@@ -153,7 +152,7 @@ class FairseqTask(object): ...@@ -153,7 +152,7 @@ class FairseqTask(object):
) )
# return a reusable, sharded iterator # return a reusable, sharded iterator
self.epoch_iter = iterators.EpochBatchIterator( epoch_iter = iterators.EpochBatchIterator(
dataset=dataset, dataset=dataset,
collate_fn=dataset.collater, collate_fn=dataset.collater,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
...@@ -163,7 +162,8 @@ class FairseqTask(object): ...@@ -163,7 +162,8 @@ class FairseqTask(object):
num_workers=num_workers, num_workers=num_workers,
epoch=epoch, epoch=epoch,
) )
return self.epoch_iter self.dataset_to_epoch_iter[dataset] = epoch_iter
return epoch_iter
def build_model(self, args): def build_model(self, args):
""" """
......
...@@ -286,7 +286,7 @@ class MultiLingualMaskedLMTask(FairseqTask): ...@@ -286,7 +286,7 @@ class MultiLingualMaskedLMTask(FairseqTask):
): ):
# Recreate epoch iterator every epoch cause the underlying # Recreate epoch iterator every epoch cause the underlying
# datasets are dynamic due to sampling. # datasets are dynamic due to sampling.
self.epoch_iter = None self.dataset_to_epoch_iter = None
return super().get_batch_iterator( return super().get_batch_iterator(
dataset, max_tokens, max_sentences, max_positions, dataset, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, required_batch_size_multiple, ignore_invalid_inputs, required_batch_size_multiple,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment