Commit 18d27e00 authored by wangwei990215

initial commit

parent 541f4c7a
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from . import BaseWrapperDataset
class SortDataset(BaseWrapperDataset):
    def __init__(self, dataset, sort_order):
        super().__init__(dataset)
        if not isinstance(sort_order, (list, tuple)):
            sort_order = [sort_order]
        self.sort_order = sort_order

        assert all(len(so) == len(dataset) for so in sort_order)

    def ordered_indices(self):
        return np.lexsort(self.sort_order)
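SortDataset.ordered_indices simply delegates to np.lexsort, which treats the *last* array in sort_order as the primary key. A small standalone sketch with made-up arrays (not part of this commit):

import numpy as np

lengths = np.array([5, 3, 9, 3])
ids = np.array([0, 1, 2, 3])
# lexsort sorts by the last key first, so the primary key (lengths) goes last
order = np.lexsort([ids, lengths])
print(order)  # -> [1 3 0 2]: the two length-3 items come first, ordered by id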
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class StripTokenDataset(BaseWrapperDataset):
    def __init__(self, dataset, id_to_strip):
        super().__init__(dataset)
        self.id_to_strip = id_to_strip

    def __getitem__(self, index):
        item = self.dataset[index]
        while len(item) > 0 and item[-1] == self.id_to_strip:
            item = item[:-1]
        while len(item) > 0 and item[0] == self.id_to_strip:
            item = item[1:]
        return item
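A standalone sketch of the stripping loop above, with a toy tensor and an assumed id_to_strip of 2 (values are illustrative only):

import torch

id_to_strip = 2
item = torch.tensor([2, 7, 8, 9, 2, 2])
# strip trailing then leading occurrences, exactly as __getitem__ does
while len(item) > 0 and item[-1] == id_to_strip:
    item = item[:-1]
while len(item) > 0 and item[0] == id_to_strip:
    item = item[1:]
print(item)  # -> tensor([7, 8, 9]); only leading/trailing ids are removed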
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from . import BaseWrapperDataset
logger = logging.getLogger(__name__)
class SubsampleDataset(BaseWrapperDataset):
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
Args:
dataset (~torch.utils.data.Dataset): dataset to subsample
size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
"""
    def __init__(self, dataset, size_ratio, shuffle=False):
        super().__init__(dataset)
        assert size_ratio < 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        self.indices = np.random.choice(
            list(range(len(self.dataset))), self.actual_size, replace=False
        )
        self.shuffle = shuffle
        logger.info(
            "subsampled dataset from {} to {} (ratio={})".format(
                len(self.dataset), self.actual_size, size_ratio
            )
        )

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

    def __len__(self):
        return self.actual_size

    def collater(self, samples):
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        return self.dataset.sizes[self.indices]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        return self.dataset.num_tokens(self.indices[index])

    def size(self, index):
        return self.dataset.size(self.indices[index])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)

    def prefetch(self, indices):
        self.dataset.prefetch(self.indices[indices])
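The constructor's size arithmetic, sketched on a toy 10-example dataset with size_ratio=0.25 (numbers are made up; np.random.choice with replace=False guarantees unique indices):

import numpy as np

size_ratio, dataset_len = 0.25, 10
actual_size = np.ceil(dataset_len * size_ratio).astype(int)  # -> 3
indices = np.random.choice(list(range(dataset_len)), actual_size, replace=False)
print(actual_size, indices)  # e.g. 3 [7 1 4]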
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from fairseq.data import FairseqDataset, plasma_utils
class TokenBlockDataset(FairseqDataset):
"""Break a Dataset of tokens into blocks.
Args:
dataset (~torch.utils.data.Dataset): dataset to break into blocks
sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
block_size (int): maximum block size (ignored in 'eos' break mode)
break_mode (str, optional): Mode used for breaking tokens. Values can
be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'complete_doc': similar to 'complete' mode, but do not
cross document boundaries
- 'eos': each block contains one sentence (block_size is ignored)
include_targets (bool, optional): return next tokens as targets
(default: False).
document_sep_len (int, optional): document separator size (required for
'complete_doc' break mode). Typically 1 if the sentences have eos
and 0 otherwise.
"""
    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
    ):
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) == len(sizes)
        assert len(dataset) > 0

        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # For the "eos" break mode, block_size is not a required parameter.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
        self._sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
        if break_mode == "eos":
            # much faster version for eos break mode
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    np.zeros(
                        len(sizes), dtype=np.int64
                    ),  # starting offset within starting index
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes,
                slice_indices,
            )
        self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
        self._sizes = plasma_utils.PlasmaArray(self._sizes)
        self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)

    @property
    def slice_indices(self):
        return self._slice_indices.array

    @property
    def sizes(self):
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)

    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]

        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )
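The break modes are easiest to see on a toy example. The sketch below reproduces the 'none' mode arithmetic in plain NumPy (the Cython helper in token_block_utils_fast does the same thing faster); the sentence sizes are made up:

import numpy as np

sizes = np.array([4, 3, 5], dtype=np.int64)  # three sentences, 12 tokens total
block_size = 5
total_size = int(sizes.sum())
num_blocks = int(np.ceil(total_size / block_size))
slice_indices = np.array(
    [[i * block_size, min((i + 1) * block_size, total_size)] for i in range(num_blocks)]
)
print(slice_indices)  # -> [[0 5] [5 10] [10 12]]; blocks ignore sentence boundaries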
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from itertools import chain
from libc.math cimport ceil
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
DTYPE = np.int64
ctypedef int64_t DTYPE_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
    cdef DTYPE_t total_size = sizes.sum()
    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices

    cdef DTYPE_t i
    cdef DTYPE_t start
    cdef DTYPE_t end

    for i in range(length):
        start = i * block_size
        end = min(start + block_size, total_size)
        slice_indices_view[i][0] = start
        slice_indices_view[i][1] = end

    return slice_indices

cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
"""
Faster function to convert DTYPE_t list of list.
Only fast when there are huge number of rows and low number of columns.
"""
cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
return flat.reshape((len(list_of_list), -1))
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
    cdef DTYPE_t tok_idx = 0
    cdef DTYPE_t sz_idx = 0
    cdef DTYPE_t curr_size = 0
    cdef DTYPE_t i = 0
    cdef DTYPE_t length
    cdef DTYPE_t total_size
    cdef DTYPE_t[:] sizes_view = sizes
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
    cdef list slice_indices_list = []

    if break_mode is None or break_mode == 'none':
        slice_indices = _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        while sz_idx < len(sizes_view):
            if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
        if curr_size > 0:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'complete_doc':
        while sz_idx < len(sizes_view):
            if (
                (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
                # an empty sentence indicates end-of-document:
                and sizes_view[sz_idx] != document_sep_len
            ):
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                # Only keep non-empty documents.
                if curr_size > 1:
                    slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
                if sizes_view[sz_idx] == document_sep_len:
                    tok_idx += sizes_view[sz_idx]
                    sz_idx += 1
        if curr_size > 1:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        slice_indices = _fast_convert_to_np_array(slice_indices_list)
    elif break_mode == 'eos':
        slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
        cumsum = sizes.cumsum(axis=0)
        slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
        slice_indices[:, 1] = cumsum
    else:
        raise ValueError('Invalid break_mode: ' + break_mode)
    return slice_indices

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
    cdef DTYPE_t start_ds_idx
    cdef DTYPE_t start_offset
    cdef DTYPE_t end_ds_idx
    cdef DTYPE_t i
    cdef DTYPE_t s
    cdef DTYPE_t e
    cdef DatasetSearcher ds = DatasetSearcher(sizes)
    cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
    cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef Py_ssize_t x_max = slice_indices.shape[0]

    for i in range(x_max):
        s = slice_indices_view[i][0]
        e = slice_indices_view[i][1]
        ds.seek(s)
        start_ds_idx = ds.current_index
        start_offset = ds.current_offset
        if e <= s:
            end_ds_idx = start_ds_idx
        else:
            ds.seek(e - 1)
            end_ds_idx = ds.current_index
        block_to_dataset_index_view[i][0] = start_ds_idx  # starting index in dataset
        block_to_dataset_index_view[i][1] = start_offset  # starting offset within starting index
        block_to_dataset_index_view[i][2] = end_ds_idx  # ending index in dataset
    return block_to_dataset_index

cdef class DatasetSearcher(object):
"""Helper for mapping "flat" indices to indices and offsets in an
underlying dataset."""
cdef DTYPE_t current_i
cdef DTYPE_t current_offset
cdef DTYPE_t current_index
cdef DTYPE_t[:] sizes
def __init__(self, DTYPE_t[:] sizes):
self.sizes = sizes
self.reset()
cdef reset(self):
self.current_offset = 0 # offset within current index in underlying dataset
self.current_i = 0 # "flat" index
self.current_index = 0 # index in underlying dataset
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef int step(self, DTYPE_t i):
cdef DTYPE_t to_consume
cdef DTYPE_t remaining
if i < self.current_i:
self.reset()
if i > self.current_i:
to_consume = i - self.current_i
remaining = self.sizes[self.current_index] - self.current_offset
if remaining > to_consume:
self.current_offset += to_consume
self.current_i += to_consume
else:
assert remaining >= 0
self.current_i += remaining
self.current_index += 1
self.current_offset = 0
return 1
return 0
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef seek(self, DTYPE_t i):
cdef int not_done = 1
while not_done == 1:
not_done = self.step(i)
assert self.current_i == i
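For the 'eos' branch above, a pure-NumPy equivalent on toy sizes shows how the cumulative sum turns each sentence into exactly one block (arrays are illustrative):

import numpy as np

sizes = np.array([4, 3, 5], dtype=np.int64)
slice_indices = np.zeros((len(sizes), 2), dtype=np.int64)
cumsum = sizes.cumsum(axis=0)
slice_indices[1:, 0] = cumsum[:-1]  # each block starts where the previous one ended
slice_indices[:, 1] = cumsum        # and ends at the running total
print(slice_indices)  # -> [[0 4] [4 7] [7 12]]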
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class TransformEosDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to wrap
eos (int): index of the end-of-sentence symbol
append_eos_to_src (bool, optional): append EOS to the end of src
remove_eos_from_src (bool, optional): remove EOS from the end of src
append_eos_to_tgt (bool, optional): append EOS to the end of tgt
remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
"""
    def __init__(
        self,
        dataset,
        eos,
        append_eos_to_src=False,
        remove_eos_from_src=False,
        append_eos_to_tgt=False,
        remove_eos_from_tgt=False,
        has_target=True,
    ):
        if not isinstance(dataset, FairseqDataset):
            raise ValueError("dataset must be an instance of FairseqDataset")
        if append_eos_to_src and remove_eos_from_src:
            raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
        if append_eos_to_tgt and remove_eos_from_tgt:
            raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")

        self.dataset = dataset
        self.eos = torch.LongTensor([eos])
        self.append_eos_to_src = append_eos_to_src
        self.remove_eos_from_src = remove_eos_from_src
        self.append_eos_to_tgt = append_eos_to_tgt
        self.remove_eos_from_tgt = remove_eos_from_tgt
        self.has_target = has_target

        # precompute how we should adjust the reported sizes
        self._src_delta = 0
        self._src_delta += 1 if append_eos_to_src else 0
        self._src_delta -= 1 if remove_eos_from_src else 0
        self._tgt_delta = 0
        self._tgt_delta += 1 if append_eos_to_tgt else 0
        self._tgt_delta -= 1 if remove_eos_from_tgt else 0

        self._checked_src = False
        self._checked_tgt = False

    def _check_src(self, src, expect_eos):
        if not self._checked_src:
            assert (src[-1] == self.eos[0]) == expect_eos
            self._checked_src = True

    def _check_tgt(self, tgt, expect_eos):
        if self.has_target and not self._checked_tgt:
            assert (tgt[-1] == self.eos[0]) == expect_eos
            self._checked_tgt = True

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        def transform(item):
            if self.append_eos_to_src:
                self.eos = self.eos.to(device=item["source"].device)
                self._check_src(item["source"], expect_eos=False)
                item["source"] = torch.cat([item["source"], self.eos])
            if self.remove_eos_from_src:
                self.eos = self.eos.to(device=item["source"].device)
                self._check_src(item["source"], expect_eos=True)
                item["source"] = item["source"][:-1]
            if self.append_eos_to_tgt:
                self.eos = self.eos.to(device=item["target"].device)
                self._check_tgt(item["target"], expect_eos=False)
                item["target"] = torch.cat([item["target"], self.eos])
            if self.remove_eos_from_tgt:
                self.eos = self.eos.to(device=item["target"].device)
                self._check_tgt(item["target"], expect_eos=True)
                item["target"] = item["target"][:-1]
            return item

        samples = list(map(transform, samples))
        return self.dataset.collater(samples)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        if self.has_target:
            src_len, tgt_len = self.dataset.size(index)
            return (src_len + self._src_delta, tgt_len + self._tgt_delta)
        else:
            return self.dataset.size(index)

    def ordered_indices(self):
        # NOTE: we assume that the ordering does not change based on the
        # addition or removal of eos
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
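What the collater-time transform does to one sample dict, sketched with toy tensors and an assumed eos id of 2 (no real FairseqDataset involved):

import torch

eos = torch.LongTensor([2])
item = {"source": torch.tensor([7, 8, 9]), "target": torch.tensor([5, 6, 2])}
item["source"] = torch.cat([item["source"], eos])  # append_eos_to_src=True
item["target"] = item["target"][:-1]               # remove_eos_from_tgt=True
print(item["source"], item["target"])  # -> tensor([7, 8, 9, 2]) tensor([5, 6])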
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from . import FairseqDataset
class TransformEosLangPairDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
collated samples of language pair dataset.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
LanguagePairDataset schema
src_eos (int): original source end-of-sentence symbol index to be replaced
new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
beginning of 'prev_output_tokens'
"""
    def __init__(
        self,
        dataset: FairseqDataset,
        src_eos: int,
        new_src_eos: Optional[int] = None,
        tgt_bos: Optional[int] = None,
        new_tgt_bos: Optional[int] = None,
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.new_src_eos = new_src_eos
        self.tgt_bos = tgt_bos
        self.new_tgt_bos = new_tgt_bos

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, **extra_args):
        samples = self.dataset.collater(samples, **extra_args)

        if self.new_src_eos is not None:
            if self.dataset.left_pad_source:
                assert (
                    samples["net_input"]["src_tokens"][:, -1] != self.src_eos
                ).sum() == 0
                samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos
            else:
                eos_idx = samples["net_input"]["src_lengths"] - 1
                assert (
                    samples["net_input"]["src_tokens"][
                        torch.arange(eos_idx.size(0)), eos_idx
                    ]
                    != self.src_eos
                ).sum() == 0
                eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1)
                samples["net_input"]["src_tokens"].scatter_(
                    1, eos_idx, self.new_src_eos
                )

        if (
            self.new_tgt_bos is not None
            and "prev_output_tokens" in samples["net_input"]
        ):
            if self.dataset.left_pad_target:
                # TODO: support different padding direction on target side
                raise NotImplementedError(
                    "TransformEosLangPairDataset does not implement --left-pad-target True option"
                )
            else:
                assert (
                    samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
                ).sum() == 0
                samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos

        return samples

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    @property
    def sizes(self):
        # dataset.sizes can be dynamically computed
        return self.dataset.sizes

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
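A sketch of the right-padded source branch in collater: the eos at each sequence's true length is overwritten with the new symbol. Token ids and lengths below are made up:

import torch

src_tokens = torch.tensor([[7, 8, 2, 0],   # 0 = pad, 2 = old src_eos
                           [5, 2, 0, 0]])
src_lengths = torch.tensor([3, 2])
new_src_eos = 9
eos_idx = (src_lengths - 1).view(-1, 1)
src_tokens.scatter_(1, eos_idx, new_src_eos)
print(src_tokens)  # -> tensor([[7, 8, 9, 0], [5, 9, 0, 0]])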
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .utils import ChoiceEnum, FairseqDataclass
__all__ = ["FairseqDataclass", "ChoiceEnum"]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.dataclass.utils import ChoiceEnum
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"])
DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"])
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
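ChoiceEnum is defined in fairseq.dataclass.utils; the stand-in below is my own minimal sketch of the pattern (not the fairseq implementation), showing how a list of strings becomes an Enum whose members back constants like the *_CHOICES values above:

from enum import Enum

def choice_enum(choices):
    # minimal stand-in: one member per string, value equal to the name
    return Enum("Choices", {c: c for c in choices})

LOG_FORMAT = choice_enum(["json", "none", "simple", "tqdm"])
print([m.value for m in LOG_FORMAT])  # -> ['json', 'none', 'simple', 'tqdm']
print(LOG_FORMAT["tqdm"].value)       # -> 'tqdm'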