Commit 7432130e authored by Myle Ott, committed by Facebook Github Bot

rm default_key from MultiCorpusSampledDataset

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/575

Differential Revision: D15318004

Pulled By: myleott

fbshipit-source-id: ad918d71b1bd8074decf5ec3463dd9bc9487bbe9
parent 2c278ff0
@@ -28,24 +28,19 @@ class MultiCorpusSampledDataset(FairseqDataset):
         datasets: an OrderedDict of FairseqDataset instances.
         sampling_func: A function for sampling over list of dataset keys.
             Default strategy is to sample uniformly.
-        default_key: string which specifies the default key to be used for
-            generating dummy batches etc.
     """

     def __init__(
         self,
         datasets: Dict[str, FairseqDataset],
         sampling_func: Callable[[List], int] = None,
-        default_key: str = "",
     ):
         super().__init__()
         assert isinstance(datasets, OrderedDict)
-        assert default_key in datasets
         self.datasets = datasets
         if sampling_func is None:
             sampling_func = uniform_sampler
         self.sampling_func = sampling_func
-        self.default_key = default_key
         self.total_num_instances = 0
         for _, dataset in datasets.items():
...
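When no sampling_func is supplied, the constructor above falls back to uniform_sampler. As a rough sketch (not copied from the source), the default sampler simply picks one of the provided dataset keys uniformly at random:

```python
import numpy as np


def uniform_sampler(keys):
    # Sketch of the default strategy described in the docstring above:
    # choose one key uniformly at random. The real fairseq helper may
    # differ in minor details.
    return np.random.choice(keys, 1).item()
```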
@@ -62,7 +62,6 @@ class CrossLingualLMTask(FairseqTask):
         self.seed = args.seed
         self.distributed_world_size = args.distributed_world_size
         self.langs2id = self._lang_to_id(args.monolingual_langs)
-        self.default_key = None

     def _lang_to_id(
         self,
...
@@ -155,9 +154,6 @@ class CrossLingualLMTask(FairseqTask):
         dataset_map = OrderedDict()
         for lang in self.langs2id.keys():
-            if self.default_key is None:
-                self.default_key = lang
-
             # Datasets are expected to be in "split.lang" format (Eg: train.en)
             language_split = '{}.{}'.format(split, lang)
...
@@ -177,9 +173,7 @@ class CrossLingualLMTask(FairseqTask):
             seed=self.seed,
         )
-        self.datasets[split] = MultiCorpusSampledDataset(
-            dataset_map, default_key=self.default_key
-        )
+        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
         print('| {} {} {} examples'.format(
             self.args.data.split(':')[epoch], split, len(self.datasets[split]))
         )
@@ -53,13 +53,12 @@ class TestMultiCorpusSampledDataset(unittest.TestCase):
         np.random.seed(0)
         if sampling_func is None:
             m = MultiCorpusSampledDataset(
-                OrderedDict({0: self.dataset_1, 1: self.dataset_2}), default_key=0
+                OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
             )
         else:
             m = MultiCorpusSampledDataset(
                 OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
                 sampling_func=sampling_func,
-                default_key=0,
             )
         m.ordered_indices()
         count_sample_from_first_dataset = 0
...
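For reference, a minimal, self-contained usage sketch of the updated constructor. ToyDataset is a hypothetical stand-in (not part of fairseq) that implements just enough of the FairseqDataset interface to build the wrapper; the import paths assume a fairseq checkout that includes this commit.

```python
from collections import OrderedDict

import numpy as np

from fairseq.data import FairseqDataset
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset


class ToyDataset(FairseqDataset):
    """Hypothetical stand-in implementing just enough of FairseqDataset."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

    def ordered_indices(self):
        return np.arange(len(self))

    def num_tokens(self, index):
        return 1

    def size(self, index):
        return 1

    def collater(self, samples):
        return samples


dataset_map = OrderedDict([
    ("en", ToyDataset(range(10))),
    ("fr", ToyDataset(range(5))),
])

# With no sampling_func, each example is drawn from a key chosen uniformly
# at random (the uniform_sampler fallback shown in the first hunk).
uniform = MultiCorpusSampledDataset(dataset_map)
print(len(uniform))  # 15: the sum of the per-dataset lengths

# A custom sampling_func receives the list of keys and returns the key to
# draw from; here it always picks the first dataset, mirroring the test above.
first_only = MultiCorpusSampledDataset(
    dataset_map,
    sampling_func=lambda keys: keys[0],
)
```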