Commit 7df61696 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fairseq0.10.2

parents
Pipeline #471 failed with stages
in 0 seconds
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import torch
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
def _flatten(dico, prefix=None):
"""Flatten a nested dictionary."""
new_dico = OrderedDict()
if isinstance(dico, dict):
prefix = prefix + "." if prefix is not None else ""
for k, v in dico.items():
if v is None:
continue
new_dico.update(_flatten(v, prefix + k))
elif isinstance(dico, list):
for i, v in enumerate(dico):
new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
else:
new_dico = OrderedDict({prefix: dico})
return new_dico
def _unflatten(dico):
"""Unflatten a flattened dictionary into a nested dictionary."""
new_dico = OrderedDict()
for full_k, v in dico.items():
full_k = full_k.split(".")
node = new_dico
for k in full_k[:-1]:
if k.startswith("[") and k.endswith("]"):
k = int(k[1:-1])
if k not in node:
node[k] = OrderedDict()
node = node[k]
node[full_k[-1]] = v
return new_dico
class NestedDictionaryDataset(FairseqDataset):
def __init__(self, defn, sizes=None):
super().__init__()
self.defn = _flatten(defn)
self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes
first = None
for v in self.defn.values():
if not isinstance(
v,
(
FairseqDataset,
torch.utils.data.Dataset,
),
):
raise ValueError("Expected Dataset but found: {}".format(v.__class__))
first = first or v
if len(v) > 0:
assert len(v) == len(first), "dataset lengths must match"
self._len = len(first)
def __getitem__(self, index):
return OrderedDict((k, ds[index]) for k, ds in self.defn.items())
def __len__(self):
return self._len
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
if len(samples) == 0:
return {}
sample = OrderedDict()
for k, ds in self.defn.items():
try:
sample[k] = ds.collater([s[k] for s in samples])
except NotImplementedError:
sample[k] = default_collate([s[k] for s in samples])
return _unflatten(sample)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return max(s[index] for s in self.sizes)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
if len(self.sizes) == 1:
return self.sizes[0][index]
else:
return (s[index] for s in self.sizes)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return any(ds.supports_prefetch for ds in self.defn.values())
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
for ds in self.defn.values():
if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
@property
def can_reuse_epoch_itr_across_epochs(self):
return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.defn.values():
ds.set_epoch(epoch)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from fairseq.data import data_utils
class WordNoising(object):
"""Generate a noisy version of a sentence, without changing words themselves."""
def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
self.dictionary = dictionary
self.bpe_end = None
if bpe_cont_marker:
self.bpe_end = np.array(
[
not self.dictionary[i].endswith(bpe_cont_marker)
for i in range(len(self.dictionary))
]
)
elif bpe_end_marker:
self.bpe_end = np.array(
[
self.dictionary[i].endswith(bpe_end_marker)
for i in range(len(self.dictionary))
]
)
self.get_word_idx = (
self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx
)
def noising(self, x, lengths, noising_prob=0.0):
raise NotImplementedError()
def _get_bpe_word_idx(self, x):
"""
Given a list of BPE tokens, for every index in the tokens list,
return the index of the word grouping that it belongs to.
For example, for input x corresponding to ["how", "are", "y@@", "ou"],
return [[0], [1], [2], [2]].
"""
# x: (T x B)
bpe_end = self.bpe_end[x]
if x.size(0) == 1 and x.size(1) == 1:
# Special case when we only have one word in x. If x = [[N]],
# bpe_end is a scalar (bool) instead of a 2-dim array of bools,
# which makes the sum operation below fail.
return np.array([[0]])
# do a reduce front sum to generate word ids
word_idx = bpe_end[::-1].cumsum(0)[::-1]
word_idx = word_idx.max(0)[None, :] - word_idx
return word_idx
def _get_token_idx(self, x):
"""
This is to extend noising functions to be able to apply to non-bpe
tokens, e.g. word or characters.
"""
x = torch.t(x)
word_idx = np.array([range(len(x_i)) for x_i in x])
return np.transpose(word_idx)
class WordDropout(WordNoising):
"""Randomly drop input words. If not passing blank_idx (default is None),
then dropped words will be removed. Otherwise, it will be replaced by the
blank_idx."""
def __init__(
self,
dictionary,
default_dropout_prob=0.1,
bpe_cont_marker="@@",
bpe_end_marker=None,
):
super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
self.default_dropout_prob = default_dropout_prob
def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
if dropout_prob is None:
dropout_prob = self.default_dropout_prob
# x: (T x B), lengths: B
if dropout_prob == 0:
return x, lengths
assert 0 < dropout_prob < 1
# be sure to drop entire words
word_idx = self.get_word_idx(x)
sentences = []
modified_lengths = []
for i in range(lengths.size(0)):
# Since dropout probabilities need to apply over non-pad tokens,
# it is not trivial to generate the keep mask without consider
# input lengths; otherwise, this could be done outside the loop
# We want to drop whole words based on word_idx grouping
num_words = max(word_idx[:, i]) + 1
# ith example: [x0, x1, ..., eos, pad, ..., pad]
# We should only generate keep probs for non-EOS tokens. Thus if the
# input sentence ends in EOS, the last word idx is not included in
# the dropout mask generation and we append True to always keep EOS.
# Otherwise, just generate the dropout mask for all word idx
# positions.
has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
if has_eos: # has eos?
keep = np.random.rand(num_words - 1) >= dropout_prob
keep = np.append(keep, [True]) # keep EOS symbol
else:
keep = np.random.rand(num_words) >= dropout_prob
words = x[: lengths[i], i].tolist()
# TODO: speed up the following loop
# drop words from the input according to keep
new_s = [
w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words)
]
new_s = [w for w in new_s if w is not None]
# we need to have at least one word in the sentence (more than the
# start / end sentence symbols)
if len(new_s) <= 1:
# insert at beginning in case the only token left is EOS
# EOS should be at end of list.
new_s.insert(0, words[np.random.randint(0, len(words))])
assert len(new_s) >= 1 and (
not has_eos # Either don't have EOS at end or last token is EOS
or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
), "New sentence is invalid."
sentences.append(new_s)
modified_lengths.append(len(new_s))
# re-construct input
modified_lengths = torch.LongTensor(modified_lengths)
modified_x = torch.LongTensor(
modified_lengths.max(), modified_lengths.size(0)
).fill_(self.dictionary.pad())
for i in range(modified_lengths.size(0)):
modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))
return modified_x, modified_lengths
class WordShuffle(WordNoising):
"""Shuffle words by no more than k positions."""
def __init__(
self,
dictionary,
default_max_shuffle_distance=3,
bpe_cont_marker="@@",
bpe_end_marker=None,
):
super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
self.default_max_shuffle_distance = 3
def noising(self, x, lengths, max_shuffle_distance=None):
if max_shuffle_distance is None:
max_shuffle_distance = self.default_max_shuffle_distance
# x: (T x B), lengths: B
if max_shuffle_distance == 0:
return x, lengths
# max_shuffle_distance < 1 will return the same sequence
assert max_shuffle_distance > 1
# define noise word scores
noise = np.random.uniform(
0,
max_shuffle_distance,
size=(x.size(0), x.size(1)),
)
noise[0] = -1 # do not move start sentence symbol
# be sure to shuffle entire words
word_idx = self.get_word_idx(x)
x2 = x.clone()
for i in range(lengths.size(0)):
length_no_eos = lengths[i]
if x[lengths[i] - 1, i] == self.dictionary.eos():
length_no_eos = lengths[i] - 1
# generate a random permutation
scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
# ensure no reordering inside a word
scores += 1e-6 * np.arange(length_no_eos.item())
permutation = scores.argsort()
# shuffle words
x2[:length_no_eos, i].copy_(
x2[:length_no_eos, i][torch.from_numpy(permutation)]
)
return x2, lengths
class UnsupervisedMTNoising(WordNoising):
"""
Implements the default configuration for noising in UnsupervisedMT
(github.com/facebookresearch/UnsupervisedMT)
"""
def __init__(
self,
dictionary,
max_word_shuffle_distance,
word_dropout_prob,
word_blanking_prob,
bpe_cont_marker="@@",
bpe_end_marker=None,
):
super().__init__(dictionary)
self.max_word_shuffle_distance = max_word_shuffle_distance
self.word_dropout_prob = word_dropout_prob
self.word_blanking_prob = word_blanking_prob
self.word_dropout = WordDropout(
dictionary=dictionary,
bpe_cont_marker=bpe_cont_marker,
bpe_end_marker=bpe_end_marker,
)
self.word_shuffle = WordShuffle(
dictionary=dictionary,
bpe_cont_marker=bpe_cont_marker,
bpe_end_marker=bpe_end_marker,
)
def noising(self, x, lengths):
# 1. Word Shuffle
noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising(
x=x,
lengths=lengths,
max_shuffle_distance=self.max_word_shuffle_distance,
)
# 2. Word Dropout
noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
x=noisy_src_tokens,
lengths=noisy_src_lengths,
dropout_prob=self.word_dropout_prob,
)
# 3. Word Blanking
noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
x=noisy_src_tokens,
lengths=noisy_src_lengths,
dropout_prob=self.word_blanking_prob,
blank_idx=self.dictionary.unk(),
)
return noisy_src_tokens
class NoisingDataset(torch.utils.data.Dataset):
def __init__(
self,
src_dataset,
src_dict,
seed,
noiser=None,
noising_class=UnsupervisedMTNoising,
**kwargs
):
"""
Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
samples based on the supplied noising configuration.
Args:
src_dataset (~torch.utils.data.Dataset): dataset to wrap.
to build self.src_dataset --
a LanguagePairDataset with src dataset as the source dataset and
None as the target dataset. Should NOT have padding so that
src_lengths are accurately calculated by language_pair_dataset
collate function.
We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects.
src_dict (~fairseq.data.Dictionary): source dictionary
seed (int): seed to use when generating random noise
noiser (WordNoising): a pre-initialized :class:`WordNoising`
instance. If this is None, a new instance will be created using
*noising_class* and *kwargs*.
noising_class (class, optional): class to use to initialize a
default :class:`WordNoising` instance.
kwargs (dict, optional): arguments to initialize the default
:class:`WordNoising` instance given by *noiser*.
"""
self.src_dataset = src_dataset
self.src_dict = src_dict
self.seed = seed
self.noiser = (
noiser
if noiser is not None
else noising_class(
dictionary=src_dict,
**kwargs,
)
)
def __getitem__(self, index):
"""
Returns a single noisy sample. Multiple samples are fed to the collater
create a noising dataset batch.
"""
src_tokens = self.src_dataset[index]
src_lengths = torch.LongTensor([len(src_tokens)])
src_tokens = src_tokens.unsqueeze(0)
# Transpose src tokens to fit expected shape of x in noising function
# (batch size, sequence length) -> (sequence length, batch size)
src_tokens_t = torch.t(src_tokens)
with data_utils.numpy_seed(self.seed + index):
noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths)
# Transpose back to expected src_tokens format
# (sequence length, 1) -> (1, sequence length)
noisy_src_tokens = torch.t(noisy_src_tokens)
return noisy_src_tokens[0]
def __len__(self):
"""
The length of the noising dataset is the length of src.
"""
return len(self.src_dataset)
@property
def supports_prefetch(self):
return self.src_dataset.supports_prefetch
def prefetch(self, indices):
if self.src_dataset.supports_prefetch:
self.src_dataset.prefetch(indices)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import FairseqDataset
class NumSamplesDataset(FairseqDataset):
def __getitem__(self, index):
return 1
def __len__(self):
return 0
def collater(self, samples):
return sum(samples)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class NumelDataset(BaseWrapperDataset):
def __init__(self, dataset, reduce=False):
super().__init__(dataset)
self.reduce = reduce
def __getitem__(self, index):
item = self.dataset[index]
if torch.is_tensor(item):
return torch.numel(item)
else:
return np.size(item)
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if self.reduce:
return sum(samples)
else:
return torch.tensor(samples)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class OffsetTokensDataset(BaseWrapperDataset):
def __init__(self, dataset, offset):
super().__init__(dataset)
self.offset = offset
def __getitem__(self, idx):
return self.dataset[idx] + self.offset
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data import data_utils
from . import BaseWrapperDataset
class PadDataset(BaseWrapperDataset):
def __init__(self, dataset, pad_idx, left_pad):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
def collater(self, samples):
return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
class LeftPadDataset(PadDataset):
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=True)
class RightPadDataset(PadDataset):
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=False)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import subprocess
import tempfile
class PlasmaArray(object):
"""
Wrapper around numpy arrays that automatically moves the data to shared
memory upon serialization. This is particularly helpful when passing numpy
arrays through multiprocessing, so that data is not unnecessarily
duplicated or pickled.
"""
def __init__(self, array):
super().__init__()
self.array = array
self.disable = array.nbytes < 134217728 # disable for arrays <128MB
self.object_id = None
self.path = None
# variables with underscores shouldn't be pickled
self._client = None
self._server = None
self._server_tmp = None
self._plasma = None
@property
def plasma(self):
if self._plasma is None and not self.disable:
try:
import pyarrow.plasma as plasma
self._plasma = plasma
except ImportError:
self._plasma = None
return self._plasma
def start_server(self):
if self.plasma is None or self._server is not None:
return
assert self.object_id is None
assert self.path is None
self._server_tmp = tempfile.NamedTemporaryFile()
self.path = self._server_tmp.name
self._server = subprocess.Popen(
[
"plasma_store",
"-m",
str(int(1.05 * self.array.nbytes)),
"-s",
self.path,
]
)
@property
def client(self):
if self._client is None:
assert self.path is not None
self._client = self.plasma.connect(self.path)
return self._client
def __getstate__(self):
if self.plasma is None:
return self.__dict__
if self.object_id is None:
self.start_server()
self.object_id = self.client.put(self.array)
state = self.__dict__.copy()
del state["array"]
state["_client"] = None
state["_server"] = None
state["_server_tmp"] = None
state["_plasma"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
if self.plasma is None:
return
self.array = self.client.get(self.object_id)
def __del__(self):
if self._server is not None:
self._server.kill()
self._server = None
self._server_tmp.close()
self._server_tmp = None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class PrependDataset(BaseWrapperDataset):
def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
super().__init__(dataset)
self.prepend_getter = prepend_getter
self.ensure_first_token = ensure_first_token_is
def __getitem__(self, idx):
item = self.dataset[idx]
is_tuple = isinstance(item, tuple)
src = item[0] if is_tuple else item
assert self.ensure_first_token is None or src[0] == self.ensure_first_token
prepend_idx = self.prepend_getter(self.dataset, idx)
assert isinstance(prepend_idx, int)
src[0] = prepend_idx
item = tuple((src,) + item[1:]) if is_tuple else src
return item
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class PrependTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
if token is not None:
self._sizes = np.array(dataset.sizes) + 1
else:
self._sizes = dataset.sizes
def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item.new([self.token]), item])
return item
@property
def sizes(self):
return self._sizes
def num_tokens(self, index):
n = self.dataset.num_tokens(index)
if self.token is not None:
n += 1
return n
def size(self, index):
n = self.dataset.size(index)
if self.token is not None:
n += 1
return n
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class RawLabelDataset(FairseqDataset):
def __init__(self, labels):
super().__init__()
self.labels = labels
def __getitem__(self, index):
return self.labels[index]
def __len__(self):
return len(self.labels)
def collater(self, samples):
return torch.tensor(samples)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class ReplaceDataset(BaseWrapperDataset):
"""Replaces tokens found in the dataset by a specified replacement token
Args:
dataset (~torch.utils.data.Dataset): dataset to replace tokens in
replace_map(Dictionary[int,int]): map of token to replace -> replacement token
offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
as many as the number of objects returned by the underlying dataset __getitem__ method.
"""
def __init__(self, dataset, replace_map, offsets):
super().__init__(dataset)
assert len(replace_map) > 0
self.replace_map = replace_map
self.offsets = offsets
def __getitem__(self, index):
item = self.dataset[index]
is_tuple = isinstance(item, tuple)
srcs = item if is_tuple else [item]
for offset, src in zip(self.offsets, srcs):
for k, v in self.replace_map.items():
src_off = src[offset:] if offset >= 0 else src[:offset]
src_off.masked_fill_(src_off == k, v)
item = srcs if is_tuple else srcs[0]
return item
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from fairseq.data import BaseWrapperDataset, plasma_utils
logger = logging.getLogger(__name__)
class ResamplingDataset(BaseWrapperDataset):
"""Randomly samples from a given dataset at each epoch.
Sampling is done with or without replacement, depending on the "replace"
parameter.
Optionally, the epoch size can be rescaled. This is potentially desirable
to increase per-epoch coverage of the base dataset (since sampling with
replacement means that many items in the dataset will be left out). In the
case of sampling without replacement, size_ratio should be strictly less
than 1.
Args:
dataset (~torch.utils.data.Dataset): dataset on which to sample.
weights (List[float]): list of probability weights
(default: None, which corresponds to uniform sampling).
replace (bool): sampling mode; True for "with replacement", or False
for "without replacement" (default: True)
size_ratio (float): the ratio to subsample to; must be positive
(default: 1.0).
batch_by_size (bool): whether or not to batch by sequence length
(default: True).
seed (int): RNG seed to use (default: 0).
epoch (int): starting epoch number (default: 1).
"""
def __init__(
self,
dataset,
weights=None,
replace=True,
size_ratio=1.0,
batch_by_size=True,
seed=0,
epoch=1,
):
super().__init__(dataset)
if weights is None:
self.weights = None
else:
assert len(weights) == len(dataset)
weights_arr = np.array(weights, dtype=np.float64)
weights_arr /= weights_arr.sum()
self.weights = plasma_utils.PlasmaArray(weights_arr)
self.replace = replace
assert size_ratio > 0.0
if not self.replace:
assert size_ratio < 1.0
self.size_ratio = float(size_ratio)
self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)
self.batch_by_size = batch_by_size
self.seed = seed
self._cur_epoch = None
self._cur_indices = None
self.set_epoch(epoch)
def __getitem__(self, index):
return self.dataset[self._cur_indices.array[index]]
def __len__(self):
return self.actual_size
@property
def sizes(self):
if isinstance(self.dataset.sizes, list):
return [s[self._cur_indices.array] for s in self.dataset.sizes]
return self.dataset.sizes[self._cur_indices.array]
def num_tokens(self, index):
return self.dataset.num_tokens(self._cur_indices.array[index])
def size(self, index):
return self.dataset.size(self._cur_indices.array[index])
def ordered_indices(self):
if self.batch_by_size:
order = [
np.arange(len(self)),
self.sizes,
] # No need to handle `self.shuffle == True`
return np.lexsort(order)
else:
return np.arange(len(self))
def prefetch(self, indices):
self.dataset.prefetch(self._cur_indices.array[indices])
@property
def can_reuse_epoch_itr_across_epochs(self):
return False
def set_epoch(self, epoch):
logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
super().set_epoch(epoch)
if epoch == self._cur_epoch:
return
self._cur_epoch = epoch
# Generate a weighted sample of indices as a function of the
# random seed and the current epoch.
rng = np.random.RandomState(
[
42, # magic number
self.seed % (2 ** 32), # global seed
self._cur_epoch, # epoch index
]
)
self._cur_indices = plasma_utils.PlasmaArray(
rng.choice(
len(self.dataset),
self.actual_size,
replace=self.replace,
p=(None if self.weights is None else self.weights.array),
)
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class RollDataset(BaseWrapperDataset):
def __init__(self, dataset, shifts):
super().__init__(dataset)
self.shifts = shifts
def __getitem__(self, index):
item = self.dataset[index]
return torch.roll(item, self.shifts)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import numpy as np
from . import FairseqDataset
class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
Shorter datasets are repeated in a round-robin fashion to match the length
of the longest one.
Args:
datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
:class:`~fairseq.data.FairseqDataset` instances.
eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
"""
def __init__(self, datasets, eval_key=None):
super().__init__()
assert isinstance(datasets, OrderedDict)
self.datasets = datasets
self.eval_key = eval_key
self.longest_dataset = None
self.longest_dataset_key = None
for key, dataset in datasets.items():
assert isinstance(dataset, FairseqDataset)
if self.longest_dataset is None or len(dataset) > len(self.longest_dataset):
self.longest_dataset = dataset
self.longest_dataset_key = key
self._ordered_indices = None
def _map_index(self, key, index):
assert (
self._ordered_indices is not None
), "Must call RoundRobinZipDatasets.ordered_indices() first"
return self._ordered_indices[key][index % len(self.datasets[key])]
def __getitem__(self, index):
if self.eval_key is None:
return OrderedDict(
[
(key, dataset[self._map_index(key, index)])
for key, dataset in self.datasets.items()
]
)
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]
def __len__(self):
return len(self.longest_dataset)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
if len(samples) == 0:
return None
if self.eval_key is None:
return OrderedDict(
[
(key, dataset.collater([sample[key] for sample in samples]))
for key, dataset in self.datasets.items()
]
)
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key].collater(samples)
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
# TODO make it configurable whether to use max() or sum() here
return max(
dataset.num_tokens(self._map_index(key, index))
for key, dataset in self.datasets.items()
)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return {
key: dataset.size(self._map_index(key, index))
for key, dataset in self.datasets.items()
}
def ordered_indices(self):
"""Ordered indices for batching."""
if self._ordered_indices is None:
# Call the underlying dataset's ordered_indices() here, so that we
# get the same random ordering as we would have from using the
# underlying dataset directly.
self._ordered_indices = OrderedDict(
[
(key, dataset.ordered_indices())
for key, dataset in self.datasets.items()
]
)
return np.arange(len(self))
@property
def supports_prefetch(self):
return all(
getattr(dataset, "supports_prefetch", False)
for dataset in self.datasets.values()
)
def prefetch(self, indices):
for key, dataset in self.datasets.items():
dataset.prefetch([self._map_index(key, index) for index in indices])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from fairseq.data import data_utils
from . import BaseWrapperDataset
class TruncateDataset(BaseWrapperDataset):
"""Truncate a sequence by returning the first truncation_length tokens"""
def __init__(self, dataset, truncation_length):
super().__init__(dataset)
assert truncation_length is not None
self.truncation_length = truncation_length
self.dataset = dataset
def __getitem__(self, index):
item = self.dataset[index]
item_len = item.size(0)
if item_len > self.truncation_length:
item = item[: self.truncation_length]
return item
@property
def sizes(self):
return np.minimum(self.dataset.sizes, self.truncation_length)
def __len__(self):
return len(self.dataset)
class RandomCropDataset(TruncateDataset):
"""Truncate a sequence by returning a random crop of truncation_length tokens"""
def __init__(self, dataset, truncation_length, seed=1):
super().__init__(dataset, truncation_length)
self.seed = seed
self.epoch = 0
@property
def can_reuse_epoch_itr_across_epochs(self):
return True # only the crop changes, not item sizes
def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
self.epoch = epoch
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch, index):
item = self.dataset[index]
item_len = item.size(0)
excess = item_len - self.truncation_length
if excess > 0:
start_idx = np.random.randint(0, excess)
item = item[start_idx : start_idx + self.truncation_length]
return item
def maybe_shorten_dataset(
dataset,
split,
shorten_data_split_list,
shorten_method,
tokens_per_sample,
seed,
):
truncate_split = (
split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0
)
if shorten_method == "truncate" and truncate_split:
dataset = TruncateDataset(dataset, tokens_per_sample)
elif shorten_method == "random_crop" and truncate_split:
dataset = RandomCropDataset(dataset, tokens_per_sample, seed)
return dataset
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from . import BaseWrapperDataset
class SortDataset(BaseWrapperDataset):
def __init__(self, dataset, sort_order):
super().__init__(dataset)
if not isinstance(sort_order, (list, tuple)):
sort_order = [sort_order]
self.sort_order = sort_order
assert all(len(so) == len(dataset) for so in sort_order)
def ordered_indices(self):
return np.lexsort(self.sort_order)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class StripTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, id_to_strip):
super().__init__(dataset)
self.id_to_strip = id_to_strip
def __getitem__(self, index):
item = self.dataset[index]
while len(item) > 0 and item[-1] == self.id_to_strip:
item = item[:-1]
while len(item) > 0 and item[0] == self.id_to_strip:
item = item[1:]
return item
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from . import BaseWrapperDataset
logger = logging.getLogger(__name__)
class SubsampleDataset(BaseWrapperDataset):
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
Args:
dataset (~torch.utils.data.Dataset): dataset to subsample
size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
"""
def __init__(self, dataset, size_ratio, shuffle=False):
super().__init__(dataset)
assert size_ratio < 1
self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
self.indices = np.random.choice(
list(range(len(self.dataset))), self.actual_size, replace=False
)
self.shuffle = shuffle
logger.info(
"subsampled dataset from {} to {} (ratio={})".format(
len(self.dataset), self.actual_size, size_ratio
)
)
def __getitem__(self, index):
return self.dataset[self.indices[index]]
def __len__(self):
return self.actual_size
def collater(self, samples):
return self.dataset.collater(samples)
@property
def sizes(self):
return self.dataset.sizes[self.indices]
@property
def name(self):
return self.dataset.name
def num_tokens(self, index):
return self.dataset.num_tokens(self.indices[index])
def size(self, index):
return self.dataset.size(self.indices[index])
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)
def prefetch(self, indices):
self.dataset.prefetch(self.indices[indices])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from fairseq.data import FairseqDataset, plasma_utils
class TokenBlockDataset(FairseqDataset):
"""Break a Dataset of tokens into blocks.
Args:
dataset (~torch.utils.data.Dataset): dataset to break into blocks
sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
block_size (int): maximum block size (ignored in 'eos' break mode)
break_mode (str, optional): Mode used for breaking tokens. Values can
be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'complete_doc': similar to 'complete' mode, but do not
cross document boundaries
- 'eos': each block contains one sentence (block_size is ignored)
include_targets (bool, optional): return next tokens as targets
(default: False).
document_sep_len (int, optional): document separator size (required for
'complete_doc' break mode). Typically 1 if the sentences have eos
and 0 otherwise.
"""
def __init__(
self,
dataset,
sizes,
block_size,
pad,
eos,
break_mode=None,
include_targets=False,
document_sep_len=1,
):
try:
from fairseq.data.token_block_utils_fast import (
_get_slice_indices_fast,
_get_block_to_dataset_index_fast,
)
except ImportError:
raise ImportError(
"Please build Cython components with: `pip install --editable .` "
"or `python setup.py build_ext --inplace`"
)
super().__init__()
self.dataset = dataset
self.pad = pad
self.eos = eos
self.include_targets = include_targets
assert len(dataset) == len(sizes)
assert len(dataset) > 0
if isinstance(sizes, list):
sizes = np.array(sizes, dtype=np.int64)
else:
if torch.is_tensor(sizes):
sizes = sizes.numpy()
sizes = sizes.astype(np.int64)
break_mode = break_mode if break_mode is not None else "none"
# For "eos" break-mode, block_size is not required parameters.
if break_mode == "eos" and block_size is None:
block_size = 0
slice_indices = _get_slice_indices_fast(
sizes, str(break_mode), block_size, document_sep_len
)
self._sizes = slice_indices[:, 1] - slice_indices[:, 0]
# build index mapping block indices to the underlying dataset indices
if break_mode == "eos":
# much faster version for eos break mode
block_to_dataset_index = np.stack(
[
np.arange(len(sizes)), # starting index in dataset
np.zeros(
len(sizes), dtype=np.long
), # starting offset within starting index
np.arange(len(sizes)), # ending index in dataset
],
1,
)
else:
block_to_dataset_index = _get_block_to_dataset_index_fast(
sizes,
slice_indices,
)
self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
self._sizes = plasma_utils.PlasmaArray(self._sizes)
self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
@property
def slice_indices(self):
return self._slice_indices.array
@property
def sizes(self):
return self._sizes.array
@property
def block_to_dataset_index(self):
return self._block_to_dataset_index.array
def attr(self, attr: str, index: int):
start_ds_idx, _, _ = self.block_to_dataset_index[index]
return self.dataset.attr(attr, start_ds_idx)
def __getitem__(self, index):
start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
buffer = torch.cat(
[self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
)
slice_s, slice_e = self.slice_indices[index]
length = slice_e - slice_s
s, e = start_offset, start_offset + length
item = buffer[s:e]
if self.include_targets:
# *target* is the original sentence (=item)
# *source* is shifted right by 1 (maybe left-padded with eos)
# *past_target* is shifted right by 2 (left-padded as needed)
if s == 0:
source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
past_target = torch.cat(
[item.new([self.pad, self.eos]), buffer[0 : e - 2]]
)
else:
source = buffer[s - 1 : e - 1]
if s == 1:
past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
else:
past_target = buffer[s - 2 : e - 2]
return source, item, past_target
return item
def __len__(self):
return len(self.slice_indices)
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
self.dataset.prefetch(
{
ds_idx
for index in indices
for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
for ds_idx in range(start_ds_idx, end_ds_idx + 1)
}
)
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment