Commit 90d6eac2 authored by Ning Dong, committed by Facebook Github Bot

Enable custom sampling strategy in MultiCorpusSampledDataset (#639)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/639

Add a sampling_func argument to the constructor to enable custom sampling over the list of dataset keys. The default strategy is to sample uniformly, as before.

Reviewed By: liezl200

Differential Revision: D14965774

fbshipit-source-id: f3285688a9ae3729c0ba12c22254c1144d0eea9e
parent 17cef3f6
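
For context before the diff, here is a minimal usage sketch of the new argument, assuming two existing FairseqDataset instances (dataset_a and dataset_b are hypothetical placeholders, and the weighted lambda is illustrative, not part of this commit):

    from collections import OrderedDict

    import numpy as np

    from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset

    # dataset_a and dataset_b stand in for any FairseqDataset instances.
    datasets = OrderedDict([("a", dataset_a), ("b", dataset_b)])

    # Default behaviour: sample a dataset key uniformly, as before this change.
    uniform = MultiCorpusSampledDataset(datasets, default_key="a")

    # Custom behaviour: pass any callable that maps the list of keys to one key.
    weighted = MultiCorpusSampledDataset(
        datasets,
        sampling_func=lambda keys: np.random.choice(keys, 1, p=[0.9, 0.1]).item(),
        default_key="a",
    )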
fairseq/data/multi_corpus_sampled_dataset.py

@@ -6,7 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
 from collections import OrderedDict
-from typing import Dict, List
+from typing import Callable, Dict, List
 
 import numpy as np
 
@@ -16,13 +16,13 @@ from . import FairseqDataset
 class MultiCorpusSampledDataset(FairseqDataset):
     """
     Stores multiple instances of FairseqDataset together and in every iteration
-    creates a batch by first sampling a dataset occording to a specified
+    creates a batch by first sampling a dataset according to a specified
     probability distribution and then getting instances from that dataset.
 
     Args:
         datasets: an OrderedDict of FairseqDataset instances.
-        sampling_dist: the sampling distribution used to select the dataset
-            from which the batch is created in a given iteration.
+        sampling_func: A function for sampling over list of dataset keys.
+            Default strategy is to sample uniformly.
         default_key: string which specifies the default key to be used for
             generating dummy batches etc.
     """
@@ -30,14 +30,17 @@ class MultiCorpusSampledDataset(FairseqDataset):
     def __init__(
         self,
         datasets: Dict[str, FairseqDataset],
-        sampling_dist: str = "uniform",
+        sampling_func: Callable[[List], int] = (
+            # Sample from uniform distribution
+            lambda x: np.random.choice(x, 1).item()
+        ),
         default_key: str = "",
     ):
         super().__init__()
         assert isinstance(datasets, OrderedDict)
         assert default_key in datasets
         self.datasets = datasets
-        self.sampling_dist = sampling_dist
+        self.sampling_func = sampling_func
         self.default_key = default_key
         self.total_num_instances = 0
@@ -105,15 +108,9 @@ class MultiCorpusSampledDataset(FairseqDataset):
         if len(samples) == 0:
             return None
-        if self.sampling_dist == "uniform":
-            candidates = list(self.datasets.keys())
-            selected_key = np.random.choice(candidates, 1).item()
-            selected_samples = [sample[selected_key] for sample in samples]
-            return self.datasets[selected_key].collater(selected_samples)
-        else:
-            raise NotImplementedError(
-                "Specified sampling is currently not Implemented."
-            )
+        selected_key = self.sampling_func(list(self.datasets.keys()))
+        selected_samples = [sample[selected_key] for sample in samples]
+        return self.datasets[selected_key].collater(selected_samples)
 
     def get_dummy_batch(self, num_tokens: int, max_positions: int):
         """
...
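
The contract the new collater implies: sampling_func receives list(self.datasets.keys()) once per batch and must return one of those keys. The Callable[[List], int] hint matches the integer keys used in the new test below, but any key type stored in the OrderedDict works, since the return value is used directly to index self.datasets. As a further illustration (a hypothetical sketch, not part of the commit), even a stateful sampler fits the contract:

    # Hypothetical stateful sampler: cycles deterministically through the keys.
    # Illustrative only; any callable from the key list to one key will do.
    def make_round_robin():
        state = {"step": 0}

        def round_robin(keys):
            key = keys[state["step"] % len(keys)]
            state["step"] += 1
            return key

        return round_robin

    m = MultiCorpusSampledDataset(
        datasets, sampling_func=make_round_robin(), default_key="a"
    )

The commit also adds the following test, exercising both the default uniform strategy and a custom weighted strategy: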
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest
from collections import OrderedDict

import numpy as np
import torch

from fairseq.data import LanguagePairDataset, TokenBlockDataset
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from tests.test_train import mock_dict


class TestMultiCorpusSampledDataset(unittest.TestCase):
    def setUp(self):
        d = mock_dict()
        tokens_1 = torch.LongTensor([1]).view(1, -1)
        tokens_ds1 = TokenBlockDataset(
            tokens_1,
            sizes=[tokens_1.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        self.dataset_1 = LanguagePairDataset(
            tokens_ds1, tokens_ds1.sizes, d, shuffle=False
        )
        tokens_2 = torch.LongTensor([2]).view(1, -1)
        tokens_ds2 = TokenBlockDataset(
            tokens_2,
            sizes=[tokens_2.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        self.dataset_2 = LanguagePairDataset(
            tokens_ds2, tokens_ds2.sizes, d, shuffle=False
        )

    def _test_sample_helper(
        self,
        expected_sample_from_first_ds_percentage,
        num_samples=1000,
        sampling_func=None,
    ):
        # To make sure test is not flaky
        np.random.seed(0)
        if sampling_func is None:
            m = MultiCorpusSampledDataset(
                OrderedDict({0: self.dataset_1, 1: self.dataset_2}), default_key=0
            )
        else:
            m = MultiCorpusSampledDataset(
                OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
                sampling_func=sampling_func,
                default_key=0,
            )
        m.ordered_indices()
        count_sample_from_first_dataset = 0
        for _ in range(num_samples):
            # Each collater call samples one dataset; all tokens in dataset_1
            # are 1, so the first source token identifies which dataset won.
            if m.collater([m[0], m[1]])["net_input"]["src_tokens"][0] == 1:
                count_sample_from_first_dataset += 1
        sample_from_first_ds_percentage = (
            1.0 * count_sample_from_first_dataset / num_samples
        )
        self.assertLess(
            abs(
                sample_from_first_ds_percentage
                - expected_sample_from_first_ds_percentage
            ),
            0.01,
        )

    def test_multi_corpus_sampled_dataset_uniform_sample(self):
        self._test_sample_helper(expected_sample_from_first_ds_percentage=0.5)

    def test_multi_corpus_sampled_dataset_weighted_sample(self):
        def naive_weighted_sample(weights):
            def f(l):
                # Inverse-CDF sampling: return the index of the first weight
                # whose cumulative sum exceeds a uniform draw.
                v = np.random.random()
                agg = 0
                for i, weight in enumerate(weights):
                    agg += weight
                    if agg > v:
                        return i

            return f

        self._test_sample_helper(
            expected_sample_from_first_ds_percentage=0.9,
            sampling_func=naive_weighted_sample(weights=[0.9, 0.1]),
        )
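
The naive_weighted_sample closure above performs inverse-CDF sampling over the weight list; because the test registers the datasets under integer keys 0 and 1, the returned index doubles as the dataset key. A shorter equivalent (illustrative, not part of the committed test) delegates the weighting to numpy:

    # Illustrative alternative to naive_weighted_sample: numpy's weighted
    # choice over the key indices. Not part of the committed test.
    def np_weighted_sample(weights):
        def f(keys):
            return np.random.choice(len(keys), p=weights).item()

        return f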