Commit ee28411f authored by Ning Dong, committed by Facebook GitHub Bot

Make ConcatDataset work in PytorchTranslateTask multi-path dataset loading (#730)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/730

Pull Request resolved: https://github.com/pytorch/translate/pull/528

Add/modify the necessary functions for ConcatDataset to work in PytorchTranslateTask, replacing MultiCorpusSampledDataset, which doesn't support mixed batches.

Any ideas on how to implement the collater here for mixed batches? For now I'm just using the collater of the first dataset.
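
As a rough sketch of the intended usage (illustrative only; `corpus_a` and `corpus_b` are placeholders, not variables from this diff): the datasets loaded from the individual paths are wrapped in one ConcatDataset, optionally upsampled via sample_ratios, so that a single batch can draw examples from several corpora.

from fairseq.data.concat_dataset import ConcatDataset

# `corpus_a` and `corpus_b` stand in for FairseqDataset instances loaded from
# two of the configured data paths (placeholders for illustration).
combined = ConcatDataset([corpus_a, corpus_b], sample_ratios=[1, 2])
# len(combined) == len(corpus_a) + 2 * len(corpus_b); a batch sampled from
# `combined` can mix examples from both corpora, which
# MultiCorpusSampledDataset could not do.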

Reviewed By: liezl200

Differential Revision: D15260872

fbshipit-source-id: 14b148c506e9f8ebf4fe60a49f95444d4123d76f
parent 5aebd096
fairseq/data/concat_dataset.py

@@ -8,23 +8,23 @@
 import bisect

 import numpy as np

 from . import FairseqDataset


 class ConcatDataset(FairseqDataset):

     @staticmethod
     def cumsum(sequence, sample_ratios):
         r, s = [], 0
         for e, ratio in zip(sequence, sample_ratios):
-            l = ratio * len(e)
-            r.append(l + s)
-            s += l
+            curr_len = int(ratio * len(e))
+            r.append(curr_len + s)
+            s += curr_len
         return r

     def __init__(self, datasets, sample_ratios=1):
         super(ConcatDataset, self).__init__()
-        assert len(datasets) > 0, 'datasets should not be an empty iterable'
+        assert len(datasets) > 0, "datasets should not be an empty iterable"
         self.datasets = list(datasets)
         if isinstance(sample_ratios, int):
             sample_ratios = [sample_ratios] * len(self.datasets)

@@ -36,21 +36,47 @@ class ConcatDataset(FairseqDataset):
         return self.cumulative_sizes[-1]

     def __getitem__(self, idx):
+        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+        return self.datasets[dataset_idx][sample_idx]
+
+    def _get_dataset_and_sample_index(self, idx: int):
         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
         if dataset_idx == 0:
             sample_idx = idx
         else:
             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
         sample_idx = sample_idx % self.real_sizes[dataset_idx]
-        return self.datasets[dataset_idx][sample_idx]
+        return dataset_idx, sample_idx
+
+    def collater(self, samples):
+        # For now only supports datasets with same underlying collater implementations
+        return self.datasets[0].collater(samples)
+
+    def size(self, idx: int):
+        """
+        Return an example's size as a float or tuple.
+        """
+        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+        return self.datasets[dataset_idx].size(sample_idx)
+
+    def num_tokens(self, index: int):
+        return np.max(self.size(index))

     @property
     def sizes(self):
-        return np.concatenate([np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)])
+        return np.concatenate(
+            [np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)]
+        )

     @property
     def supports_prefetch(self):
-        return any(getattr(d, 'supports_prefetch', False) for d in self.datasets)
+        return all(d.supports_prefetch for d in self.datasets)
+
+    def ordered_indices(self):
+        """
+        Returns indices sorted by length. So less padding is needed.
+        """
+        return np.argsort(self.sizes)

     def prefetch(self, indices):
         frm = 0
...
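
To make the upsampling logic concrete, here is a small self-contained walkthrough (illustrative only, not part of the commit) of how cumsum, bisect_right, and the modulo over real_sizes map a flat index back to a (dataset_idx, sample_idx) pair:

import bisect

# Two hypothetical datasets with lengths 2 and 3, with sample_ratios=[1, 2],
# mirroring ConcatDataset.cumsum and _get_dataset_and_sample_index.
real_sizes = [2, 3]
sample_ratios = [1, 2]
cumulative_sizes, s = [], 0
for size, ratio in zip(real_sizes, sample_ratios):
    s += int(ratio * size)
    cumulative_sizes.append(s)
# cumulative_sizes == [2, 8], so the concatenated dataset has length 8.

def get_dataset_and_sample_index(idx):
    # Same lookup logic as ConcatDataset._get_dataset_and_sample_index.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx % real_sizes[dataset_idx]

print([get_dataset_and_sample_index(i) for i in range(8)])
# [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 0), (1, 1), (1, 2)]
# Indices 2..7 cycle through the second dataset twice: the 2x upsampling.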
tests/test_concat_dataset.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest

import torch

from fairseq.data import LanguagePairDataset, TokenBlockDataset
from fairseq.data.concat_dataset import ConcatDataset
from tests.test_train import mock_dict


class TestConcatDataset(unittest.TestCase):
    def setUp(self):
        d = mock_dict()
        tokens_1 = torch.LongTensor([1]).view(1, -1)
        tokens_ds1 = TokenBlockDataset(
            tokens_1,
            sizes=[tokens_1.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        self.dataset_1 = LanguagePairDataset(
            tokens_ds1, tokens_ds1.sizes, d, shuffle=False
        )
        tokens_2 = torch.LongTensor([2]).view(1, -1)
        tokens_ds2 = TokenBlockDataset(
            tokens_2,
            sizes=[tokens_2.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        self.dataset_2 = LanguagePairDataset(
            tokens_ds2, tokens_ds2.sizes, d, shuffle=False
        )

    def test_concat_dataset_basics(self):
        d = ConcatDataset(
            [self.dataset_1, self.dataset_2]
        )
        assert(len(d) == 2)
        assert(d[0]['source'][0] == 1)
        assert(d[1]['source'][0] == 2)

        d = ConcatDataset(
            [self.dataset_1, self.dataset_2], sample_ratios=[1, 2]
        )
        assert(len(d) == 3)
        assert(d[0]['source'][0] == 1)
        assert(d[1]['source'][0] == 2)
        assert(d[2]['source'][0] == 2)

        d = ConcatDataset(
            [self.dataset_1, self.dataset_2], sample_ratios=[2, 1]
        )
        assert(len(d) == 3)
        assert(d[0]['source'][0] == 1)
        assert(d[1]['source'][0] == 1)
        assert(d[2]['source'][0] == 2)
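
Assuming the new file lands at tests/test_concat_dataset.py (inferred from the tests.test_train import; the path is not shown in the diff), the test can be run from the repository root with the standard unittest runner:

python -m unittest tests.test_concat_dataset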