Commit 7df61696 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fairseq0.10.2

parents
Pipeline #471 failed with stages
in 0 seconds
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return UtteranceCMVN(
_config.get("norm_means", True),
_config.get("norm_vars", True),
)
def __init__(self, norm_means=True, norm_vars=True):
self.norm_means, self.norm_vars = norm_means, norm_vars
def __repr__(self):
return (
self.__class__.__name__
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
)
def __call__(self, x):
mean = x.mean(axis=0)
square_sums = (x ** 2).sum(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean ** 2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset
logger = logging.getLogger(__name__)
class RawAudioDataset(FairseqDataset):
def __init__(
self,
sample_rate,
max_sample_size=None,
min_sample_size=None,
shuffle=True,
min_length=0,
pad=False,
normalize=False,
):
super().__init__()
self.sample_rate = sample_rate
self.sizes = []
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.min_sample_size = min_sample_size
self.min_length = min_length
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
def __getitem__(self, index):
raise NotImplementedError()
def __len__(self):
return len(self.sizes)
def postprocess(self, feats, curr_sample_rate):
if feats.dim() == 2:
feats = feats.mean(-1)
if curr_sample_rate != self.sample_rate:
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
assert feats.dim() == 1, feats.dim()
if self.normalize:
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end]
def collater(self, samples):
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
sources = [s["source"] for s in samples]
sizes = [len(s) for s in sources]
if self.pad:
target_size = min(max(sizes), self.max_sample_size)
else:
target_size = min(min(sizes), self.max_sample_size)
collated_sources = sources[0].new_zeros(len(sources), target_size)
padding_mask = (
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
if diff == 0:
collated_sources[i] = source
elif diff < 0:
assert self.pad
collated_sources[i] = torch.cat(
[source, source.new_full((-diff,), 0.0)]
)
padding_mask[i, diff:] = True
else:
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
if self.pad:
input["padding_mask"] = padding_mask
return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}
def num_tokens(self, index):
return self.size(index)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
if self.pad:
return self.sizes[index]
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)[::-1]
class FileAudioDataset(RawAudioDataset):
def __init__(
self,
manifest_path,
sample_rate,
max_sample_size=None,
min_sample_size=None,
shuffle=True,
min_length=0,
pad=False,
normalize=False,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
min_length=min_length,
pad=pad,
normalize=normalize,
)
self.fnames = []
skipped = 0
with open(manifest_path, "r") as f:
self.root_dir = f.readline().strip()
for line in f:
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_length is not None and sz < min_length:
skipped += 1
continue
self.fnames.append(items[0])
self.sizes.append(sz)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
def __getitem__(self, index):
import soundfile as sf
fname = os.path.join(self.root_dir, self.fnames[index])
wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import io
import logging
import os.path as op
import re
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from fairseq.data import (
ConcatDataset,
Dictionary,
FairseqDataset,
ResamplingDataset,
data_utils as fairseq_data_utils,
)
from fairseq.data.audio.audio_utils import get_fbank, get_waveform
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
logger = logging.getLogger(__name__)
class S2TDataConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path):
try:
import yaml
except ImportError:
print("Please install PyYAML to load YAML files for " "S2T data config")
self.config = {}
if op.isfile(yaml_path):
try:
with open(yaml_path) as f:
self.config = yaml.load(f, Loader=yaml.FullLoader)
except Exception as e:
logger.info(f"Failed to load config from {yaml_path}: {e}")
else:
logger.info(f"Cannot find {yaml_path}")
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("vocab_filename", "dict.txt")
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
return self.config.get("shuffle", False)
@property
def pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("pre_tokenizer", {"tokenizer": None})
@property
def bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("bpe_tokenizer", None)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get("prepend_tgt_lang_tag", False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
return self.config.get("input_feat_per_channel", 80)
@property
def input_channels(self):
"""The number of channels in the input audio"""
return self.config.get("input_channels", 1)
@property
def sampling_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
return self.config.get("sampling_alpha", 1.0)
@property
def use_audio_input(self):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
return self.config.get("use_audio_input", False)
@property
def audio_root(self):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
return self.config.get("audio_root", "")
def get_feature_transforms(self, split, is_train):
"""Split-specific feature transforms. Allowing train set wildcard `_train`,
evaluation set wildcard `_eval` and general wildcard `*` for matching."""
from copy import deepcopy
cfg = deepcopy(self.config)
_cur = cfg.get("transforms", {})
cur = _cur.get(split)
cur = _cur.get("_train") if cur is None and is_train else cur
cur = _cur.get("_eval") if cur is None and not is_train else cur
cur = _cur.get("*") if cur is None else cur
cfg["transforms"] = cur
return cfg
def is_npy_data(data: bytes) -> bool:
return data[0] == 147 and data[1] == 78
def is_flac_or_wav_data(data: bytes) -> bool:
is_flac = data[0] == 102 and data[1] == 76
is_wav = data[0] == 82 and data[1] == 73
return is_flac or is_wav
def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes:
with open(file_path, "rb") as f:
f.seek(offset)
data = f.read(file_size)
return data
def get_features_from_npy_or_audio(path):
ext = op.splitext(op.basename(path))[1]
if ext not in {".npy", ".flac", ".wav"}:
raise ValueError(f'Unsupported file format for "{path}"')
return np.load(path) if ext == ".npy" else get_fbank(path)
def get_features_or_waveform_from_uncompressed_zip(
path, byte_offset, byte_size, need_waveform=False
):
assert path.endswith(".zip")
data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
features_or_waveform = np.load(f)
elif is_flac_or_wav_data(data):
features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f)
else:
raise ValueError(f'Unknown file format for "{path}"')
return features_or_waveform
def get_features_or_waveform(path: str, need_waveform=False):
"""Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte
offset and length.
Args:
path (str): File path in the format of "<.npy/.wav/.flac path>" or
"<zip path>:<byte offset>:<byte length>".
need_waveform (bool): return waveform instead of features.
Returns:
features_or_waveform (numpy.ndarray): speech features or waveform.
"""
_path, *extra = path.split(":")
if not op.exists(_path):
raise FileNotFoundError(f"File not found: {_path}")
if len(extra) == 0:
if need_waveform:
return get_waveform(_path)
return get_features_from_npy_or_audio(_path)
elif len(extra) == 2:
extra = [int(i) for i in extra]
features_or_waveform = get_features_or_waveform_from_uncompressed_zip(
_path, extra[0], extra[1], need_waveform=need_waveform
)
else:
raise ValueError(f"Invalid path: {path}")
return features_or_waveform
def _collate_frames(
frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
if is_audio_input:
out = frames[0].new_zeros((len(frames), max_len))
else:
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
class SpeechToTextDataset(FairseqDataset):
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
):
self.split, self.is_train_split = split, is_train_split
self.data_cfg = data_cfg
self.audio_paths, self.n_frames = audio_paths, n_frames
self.n_samples = len(audio_paths)
assert len(n_frames) == self.n_samples > 0
assert src_texts is None or len(src_texts) == self.n_samples
assert tgt_texts is None or len(tgt_texts) == self.n_samples
assert speakers is None or len(speakers) == self.n_samples
assert src_langs is None or len(src_langs) == self.n_samples
assert tgt_langs is None or len(tgt_langs) == self.n_samples
assert ids is None or len(ids) == self.n_samples
assert (tgt_dict is None and tgt_texts is None) or (
tgt_dict is not None and tgt_texts is not None
)
self.tgt_dict = tgt_dict
self.check_tgt_lang_tag()
self.src_texts, self.tgt_texts = src_texts, tgt_texts
self.src_langs, self.tgt_langs = src_langs, tgt_langs
self.ids = ids
self.shuffle = data_cfg.shuffle if is_train_split else False
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
self.data_cfg.get_feature_transforms(split, is_train_split)
)
self.pre_tokenizer = pre_tokenizer
self.bpe_tokenizer = bpe_tokenizer
logger.info(self.__repr__())
def __repr__(self):
return (
self.__class__.__name__
+ f'(split="{self.split}", n_samples={self.n_samples}, '
f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
)
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.data_cfg.prepend_tgt_lang_tag:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
def tokenize_text(self, text: str):
if self.pre_tokenizer is not None:
text = self.pre_tokenizer.encode(text)
if self.bpe_tokenizer is not None:
text = self.bpe_tokenizer.encode(text)
return text
def __getitem__(
self, index: int
) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
source = get_features_or_waveform(
self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input
)
if self.feature_transforms is not None:
assert not self.data_cfg.use_audio_input
source = self.feature_transforms(source)
source = torch.from_numpy(source).float()
target = None
if self.tgt_texts is not None:
tokenized = self.tokenize_text(self.tgt_texts[index])
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
return index, source, target
def __len__(self):
return self.n_samples
def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
frames = _collate_frames(
[s for _, s, _ in samples], self.data_cfg.use_audio_input
)
# sort samples by descending number of frames
n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[t for _, _, t in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
target = target.index_select(0, order)
target_lengths = torch.tensor(
[t.size(0) for _, _, t in samples], dtype=torch.long
).index_select(0, order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[t for _, _, t in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(t.size(0) for _, _, t in samples)
out = {
"id": indices,
"net_input": {
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
},
"target": target,
"target_lengths": target_lengths,
"ntokens": ntokens,
"nsentences": len(samples),
}
return out
def num_tokens(self, index):
return self.n_frames[index]
def size(self, index):
t_len = 0
if self.tgt_texts is not None:
tokenized = self.tokenize_text(self.tgt_texts[index])
t_len = len(tokenized.split(" "))
return self.n_frames[index], t_len
@property
def sizes(self):
return np.array(self.n_frames)
@property
def can_reuse_epoch_itr_across_epochs(self):
return True
def ordered_indices(self):
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
# first by descending order of # of frames then by original/random order
order.append([-n for n in self.n_frames])
return np.lexsort(order)
def prefetch(self, indices):
raise False
class SpeechToTextDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text"
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[List[Dict]],
data_cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
) -> SpeechToTextDataset:
audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
speakers, src_langs, tgt_langs = [], [], []
for s in samples:
ids.extend([ss[cls.KEY_ID] for ss in s])
audio_paths.extend(
[op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
)
n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s])
tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s])
src_texts.extend(
[ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
)
speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s])
src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s])
tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s])
return SpeechToTextDataset(
split_name,
is_train_split,
data_cfg,
audio_paths,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
)
@classmethod
def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0):
"""Size ratios for temperature-based sampling
(https://arxiv.org/abs/1907.05019)"""
_sizes = np.array(sizes)
prob = _sizes / _sizes.sum()
smoothed_prob = prob ** alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
size_ratio = (smoothed_prob * _sizes.sum()) / _sizes
o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
logger.info(f"original sampling probability: {o_str}")
p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
logger.info(f"balanced sampling probability: {p_str}")
sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
logger.info(f"balanced sampling size ratio: {sr_str}")
return size_ratio.tolist()
@classmethod
def from_tsv(
cls,
root: str,
data_cfg: S2TDataConfig,
splits: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
epoch: int,
seed: int,
) -> SpeechToTextDataset:
samples = []
_splits = splits.split(",")
for split in _splits:
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
samples.append([dict(e) for e in reader])
assert len(samples) > 0
datasets = [
cls._from_list(
name,
is_train_split,
[s],
data_cfg,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
)
for name, s in zip(_splits, samples)
]
if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls._get_size_ratios(
_splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
)
datasets = [
ResamplingDataset(
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
)
for d, r in zip(datasets, size_ratios)
]
return ConcatDataset(datasets)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from . import FairseqDataset
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
"""Backtranslate a list of samples.
Given an input (*samples*) of the form:
[{'id': 1, 'source': 'hallo welt'}]
this will return:
[{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
Args:
samples (List[dict]): samples to backtranslate. Individual samples are
expected to have a 'source' key, which will become the 'target'
after backtranslation.
collate_fn (callable): function to collate samples into a mini-batch
generate_fn (callable): function to generate backtranslations
cuda (bool): use GPU for generation (default: ``True``)
Returns:
List[dict]: an updated list of samples with a backtranslated source
"""
collated_samples = collate_fn(samples)
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
generated_sources = generate_fn(s)
id_to_src = {sample["id"]: sample["source"] for sample in samples}
# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
return [
{
"id": id.item(),
"target": id_to_src[id.item()],
"source": hypos[0]["tokens"].cpu(),
}
for id, hypos in zip(collated_samples["id"], generated_sources)
]
class BacktranslationDataset(FairseqDataset):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
sentences.
tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
sentences to be backtranslated.
backtranslation_fn (callable, optional): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
Pass in None when it is not available at initialization time, and
use set_backtranslation_fn function to set it when available.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
cuda: use GPU for generation
"""
def __init__(
self,
tgt_dataset,
src_dict,
tgt_dict=None,
backtranslation_fn=None,
output_collater=None,
cuda=True,
**kwargs
):
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.output_collater = (
output_collater if output_collater is not None else tgt_dataset.collater
)
self.cuda = cuda if torch.cuda.is_available() else False
self.src_dict = src_dict
self.tgt_dict = tgt_dict
def __getitem__(self, index):
"""
Returns a single sample from *tgt_dataset*. Note that backtranslation is
not applied in this step; use :func:`collater` instead to backtranslate
a batch of samples.
"""
return self.tgt_dataset[index]
def __len__(self):
return len(self.tgt_dataset)
def set_backtranslation_fn(self, backtranslation_fn):
self.backtranslation_fn = backtranslation_fn
def collater(self, samples):
"""Merge and backtranslate a list of samples to form a mini-batch.
Using the samples from *tgt_dataset*, load a collated target sample to
feed to the backtranslation model. Then take the backtranslation with
the best score as the source and the original input as the target.
Note: we expect *tgt_dataset* to provide a function `collater()` that
will collate samples into the format expected by *backtranslation_fn*.
After backtranslation, we will feed the new list of samples (i.e., the
`(backtranslated source, original source)` pairs) to *output_collater*
and return the result.
Args:
samples (List[dict]): samples to backtranslate and collate
Returns:
dict: a mini-batch with keys coming from *output_collater*
"""
if samples[0].get("is_dummy", False):
return samples
samples = backtranslate_samples(
samples=samples,
collate_fn=self.tgt_dataset.collater,
generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
cuda=self.cuda,
)
return self.output_collater(samples)
def num_tokens(self, index):
"""Just use the tgt dataset num_tokens"""
return self.tgt_dataset.num_tokens(index)
def ordered_indices(self):
"""Just use the tgt dataset ordered_indices"""
return self.tgt_dataset.ordered_indices()
def size(self, index):
"""Return an example's size as a float or tuple. This value is used
when filtering a dataset with ``--max-positions``.
Note: we use *tgt_dataset* to approximate the length of the source
sentence, since we do not know the actual length until after
backtranslation.
"""
tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size)
@property
def supports_prefetch(self):
return getattr(self.tgt_dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class BaseWrapperDataset(FairseqDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, "collater"):
return self.dataset.collater(samples)
else:
return default_collate(samples)
@property
def sizes(self):
return self.dataset.sizes
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def attr(self, attr: str, index: int):
return self.dataset.attr(attr, index)
def prefetch(self, indices):
self.dataset.prefetch(indices)
def get_batch_shapes(self):
return self.dataset.get_batch_shapes()
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
return self.dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
def filter_indices_by_size(self, indices, max_sizes):
return self.dataset.filter_indices_by_size(indices, max_sizes)
@property
def can_reuse_epoch_itr_across_epochs(self):
return self.dataset.can_reuse_epoch_itr_across_epochs
def set_epoch(self, epoch):
super().set_epoch(epoch)
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch.nn.functional as F
from fairseq.data import BaseWrapperDataset
class BucketPadLengthDataset(BaseWrapperDataset):
"""
Bucket and pad item lengths to the nearest bucket size. This can be used to
reduce the number of unique batch shapes, which is important on TPUs since
each new batch shape requires a recompilation.
Args:
dataset (FairseqDatset): dataset to bucket
sizes (List[int]): all item sizes
num_buckets (int): number of buckets to create
pad_idx (int): padding symbol
left_pad (bool): if True, pad on the left; otherwise right pad
"""
def __init__(
self,
dataset,
sizes,
num_buckets,
pad_idx,
left_pad,
):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
assert num_buckets > 0
self.buckets = np.unique(
np.percentile(
sizes,
np.linspace(0, 100, num_buckets + 1),
interpolation="lower",
)[1:]
)
def get_bucketed_sizes(orig_sizes, buckets):
sizes = np.copy(orig_sizes)
assert np.min(sizes) >= 0
start_val = -1
for end_val in buckets:
mask = (sizes > start_val) & (sizes <= end_val)
sizes[mask] = end_val
start_val = end_val
return sizes
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
def __getitem__(self, index):
item = self.dataset[index]
bucket_size = self._bucketed_sizes[index]
num_pad = bucket_size - item.size(-1)
return F.pad(
item,
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
value=self.pad_idx,
)
@property
def sizes(self):
return self._bucketed_sizes
def num_tokens(self, index):
return self._bucketed_sizes[index]
def size(self, index):
return self._bucketed_sizes[index]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class ColorizeDataset(BaseWrapperDataset):
""" Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
def __init__(self, dataset, color_getter):
super().__init__(dataset)
self.color_getter = color_getter
def collater(self, samples):
base_collate = super().collater(samples)
if len(base_collate) > 0:
base_collate["net_input"]["colors"] = torch.tensor(
list(self.color_getter(self.dataset, s["id"]) for s in samples),
dtype=torch.long,
)
return base_collate
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import numpy as np
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class ConcatDataset(FairseqDataset):
@staticmethod
def cumsum(sequence, sample_ratios):
r, s = [], 0
for e, ratio in zip(sequence, sample_ratios):
curr_len = int(ratio * len(e))
r.append(curr_len + s)
s += curr_len
return r
def __init__(self, datasets, sample_ratios=1):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, "datasets should not be an empty iterable"
self.datasets = list(datasets)
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx][sample_idx]
def _get_dataset_and_sample_index(self, idx: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
return dataset_idx, sample_idx
def collater(self, samples, **extra_args):
# For now only supports datasets with same underlying collater implementations
if hasattr(self.datasets[0], "collater"):
return self.datasets[0].collater(samples, **extra_args)
else:
return default_collate(samples, **extra_args)
def size(self, idx: int):
"""
Return an example's size as a float or tuple.
"""
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx].size(sample_idx)
def num_tokens(self, index: int):
return np.max(self.size(index))
def attr(self, attr: str, index: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
return getattr(self.datasets[dataset_idx], attr, None)
@property
def sizes(self):
_dataset_sizes = []
for ds, sr in zip(self.datasets, self.sample_ratios):
if isinstance(ds.sizes, np.ndarray):
_dataset_sizes.append(np.tile(ds.sizes, sr))
else:
# Only support underlying dataset with single size array.
assert isinstance(ds.sizes, list)
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
return np.concatenate(_dataset_sizes)
@property
def supports_prefetch(self):
return all(d.supports_prefetch for d in self.datasets)
def ordered_indices(self):
"""
Returns indices sorted by length. So less padding is needed.
"""
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
# special handling for concatenating lang_pair_datasets
indices = np.arange(len(self))
sizes = self.sizes
tgt_sizes = (
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
)
src_sizes = (
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
)
# sort by target length, then source length
if tgt_sizes is not None:
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
else:
return np.argsort(self.sizes)
def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
if getattr(ds, "supports_prefetch", False):
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
@property
def can_reuse_epoch_itr_across_epochs(self):
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class ConcatSentencesDataset(FairseqDataset):
def __init__(self, *datasets):
super().__init__()
self.datasets = datasets
assert all(
len(ds) == len(datasets[0]) for ds in datasets
), "datasets must have the same length"
def __getitem__(self, index):
return torch.cat([ds[index] for ds in self.datasets])
def __len__(self):
return len(self.datasets[0])
def collater(self, samples):
return self.datasets[0].collater(samples)
@property
def sizes(self):
return sum(ds.sizes for ds in self.datasets)
def num_tokens(self, index):
return sum(ds.num_tokens(index) for ds in self.datasets)
def size(self, index):
return sum(ds.size(index) for ds in self.datasets)
def ordered_indices(self):
return self.datasets[0].ordered_indices()
@property
def supports_prefetch(self):
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
def prefetch(self, indices):
for ds in self.datasets:
if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
import contextlib
import itertools
import logging
import os
import warnings
from typing import Optional, Tuple
import numpy as np
import torch
logger = logging.getLogger(__name__)
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in os.listdir(path):
parts = filename.split(".")
if len(parts) >= 3 and len(parts[1].split("-")) == 2:
return parts[1].split("-")
return src, dst
def collate_tokens(
values,
pad_idx,
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
pad_to_length=None,
pad_to_multiple=1,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
if eos_idx is None:
# if no eos_idx is specified, then use the last token in src
dst[0] = src[-1]
else:
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
def load_indexed_dataset(
path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
"""A helper function for loading indexed datasets.
Args:
path (str): path to indexed dataset (e.g., 'data-bin/train')
dictionary (~fairseq.data.Dictionary): data dictionary
dataset_impl (str, optional): which dataset implementation to use. If
not provided, it will be inferred automatically. For legacy indexed
data we use the 'cached' implementation by default.
combine (bool, optional): automatically load and combine multiple
datasets. For example, if *path* is 'data-bin/train', then we will
combine 'data-bin/train', 'data-bin/train1', ... and return a
single ConcatDataset instance.
"""
from fairseq.data.concat_dataset import ConcatDataset
import fairseq.data.indexed_dataset as indexed_dataset
datasets = []
for k in itertools.count():
path_k = path + (str(k) if k > 0 else "")
path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
dataset_impl_k = dataset_impl
if dataset_impl_k is None:
dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
dataset = indexed_dataset.make_dataset(
path_k,
impl=dataset_impl_k or default,
fix_lua_indexing=True,
dictionary=dictionary,
)
if dataset is None:
break
logger.info("loaded {} examples from: {}".format(len(dataset), path_k))
datasets.append(dataset)
if not combine:
break
if len(datasets) == 0:
return None
elif len(datasets) == 1:
return datasets[0]
else:
return ConcatDataset(datasets)
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
if len(addl_seeds) > 0:
seed = int(hash((seed, *addl_seeds)) % 1e6)
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
def collect_filtered(function, iterable, filtered):
"""
Similar to :func:`filter` but collects filtered elements in ``filtered``.
Args:
function (callable): function that returns ``False`` for elements that
should be filtered
iterable (iterable): iterable to filter
filtered (list): list to store filtered elements
"""
for el in iterable:
if function(el):
yield el
else:
filtered.append(el)
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
def compare_leq(a, b):
return a <= b if not isinstance(a, tuple) else max(a) <= b
def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int):
return size_fn(idx) <= max_positions
elif isinstance(max_positions, dict):
idx_size = size_fn(idx)
assert isinstance(idx_size, dict)
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
return all(
all(
a is None or b is None or a <= b
for a, b in zip(idx_size[key], max_positions[key])
)
for key in intersect_keys
)
else:
# Hacky as heck, for the specific case of multilingual training with RoundRobin.
if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple):
return all(
a is None or b is None or compare_leq(a, b)
for a, b in zip(size_fn(idx).values(), max_positions)
)
# For MultiCorpusSampledDataset, will generalize it later
if not isinstance(size_fn(idx), Iterable):
return all(size_fn(idx) <= b for b in max_positions)
return all(
a is None or b is None or a <= b
for a, b in zip(size_fn(idx), max_positions)
)
ignored = []
itr = collect_filtered(check_size, indices, ignored)
indices = np.fromiter(itr, dtype=np.int64, count=-1)
return indices, ignored
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
"""
[deprecated] Filter indices based on their size.
Use `FairseqDataset::filter_indices_by_size` instead.
Args:
indices (List[int]): ordered list of dataset indices
dataset (FairseqDataset): fairseq dataset instance
max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception if
any elements are filtered (default: False).
"""
warnings.warn(
"data_utils.filter_by_size is deprecated. "
"Use `FairseqDataset::filter_indices_by_size` instead.",
stacklevel=2,
)
if isinstance(max_positions, float) or isinstance(max_positions, int):
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
indices = indices[dataset.sizes[indices] <= max_positions]
elif (
hasattr(dataset, "sizes")
and isinstance(dataset.sizes, list)
and len(dataset.sizes) == 1
):
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
indices = indices[dataset.sizes[0][indices] <= max_positions]
else:
indices, ignored = _filter_by_size_dynamic(
indices, dataset.size, max_positions
)
else:
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
if len(ignored) > 0 and raise_exception:
raise Exception(
(
"Size of sample #{} is invalid (={}) since max_positions={}, "
"skip this example with --skip-invalid-size-inputs-valid-test"
).format(ignored[0], dataset.size(ignored[0]), max_positions)
)
if len(ignored) > 0:
logger.warning(
(
"{} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}"
).format(len(ignored), max_positions, ignored[:10])
)
return indices
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
if max_sizes is None:
return indices, []
if type(max_sizes) in (int, float):
max_src_size, max_tgt_size = max_sizes, max_sizes
else:
max_src_size, max_tgt_size = max_sizes
if tgt_sizes is None:
ignored = indices[src_sizes[indices] > max_src_size]
else:
ignored = indices[
(src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
]
if len(ignored) > 0:
if tgt_sizes is None:
indices = indices[src_sizes[indices] <= max_src_size]
else:
indices = indices[
(src_sizes[indices] <= max_src_size)
& (tgt_sizes[indices] <= max_tgt_size)
]
return indices, ignored.tolist()
def batch_by_size(
indices,
num_tokens_fn,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
fixed_shapes=None,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be less than N or a multiple of N (default: 1).
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
only be created with the given shapes. *max_sentences* and
*required_batch_size_multiple* will be ignored (default: None).
"""
try:
from fairseq.data.data_utils_fast import (
batch_by_size_fast,
batch_fixed_shapes_fast,
)
except ImportError:
raise ImportError(
"Please build Cython components with: `pip install --editable .` "
"or `python setup.py build_ext --inplace`"
)
max_tokens = max_tokens if max_tokens is not None else -1
max_sentences = max_sentences if max_sentences is not None else -1
bsz_mult = required_batch_size_multiple
if not isinstance(indices, np.ndarray):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
if fixed_shapes is None:
return batch_by_size_fast(
indices,
num_tokens_fn,
max_tokens,
max_sentences,
bsz_mult,
)
else:
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
sort_order = np.lexsort(
[
fixed_shapes[:, 1].argsort(), # length
fixed_shapes[:, 0].argsort(), # bsz
]
)
fixed_shapes_sorted = fixed_shapes[sort_order]
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
elif symbol == "wordpiece":
sentence = sentence.replace(" ", "").replace("_", " ").strip()
elif symbol == "letter":
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol is not None and symbol != "none":
sentence = (sentence + " ").replace(symbol, "").rstrip()
return sentence
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = np.random.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = np.random.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
def get_mem_usage():
try:
import psutil
mb = 1024 * 1024
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
except ImportError:
return "N/A"
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
bsz, max_lens = lens.size(0), torch.max(lens).item()
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
return mask
def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor:
return ~lengths_to_padding_mask(lens)
This source diff could not be displayed because it is too large. You can view the blob instead.
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
ctypedef int64_t DTYPE_t
cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences):
if num_sentences == 0:
return 0
if max_sentences > 0 and num_sentences == max_sentences:
return 1
if max_tokens > 0 and num_tokens > max_tokens:
return 1
return 0
@cython.cdivision(True)
cpdef list batch_by_size_fast(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
cdef int64_t sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef int64_t mod_len
cdef int64_t i
cdef int64_t idx
cdef int64_t num_tokens
cdef DTYPE_t[:] indices_view = indices
for i in range(len(indices_view)):
idx = indices_view[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
assert max_tokens <= 0 or sample_len <= max_tokens, (
"sentence at index {} of size {} exceeds max_tokens "
"limit of {}!".format(idx, sample_len, max_tokens)
)
num_tokens = (len(batch) + 1) * sample_len
if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
)
batches.append(batch[:mod_len])
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches
cdef _find_valid_shape(
DTYPE_t[:, :] shapes_view,
int64_t num_sentences,
int64_t num_tokens,
):
"""Return index of first valid shape of -1 if none is found."""
for i in range(shapes_view.shape[0]):
if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
return i
return -1
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
cdef int64_t sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef int64_t mod_len
cdef int64_t i
cdef int64_t idx
cdef int64_t num_tokens
cdef DTYPE_t[:] indices_view = indices
cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted
for i in range(len(indices_view)):
idx = indices_view[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
if shape_idx == -1:
batches.append(batch)
batch = []
sample_lens = []
sample_len = 0
shapes_view = fixed_shapes_sorted
elif shape_idx > 0:
# small optimization for the next call to _find_valid_shape
shapes_view = shapes_view[shape_idx:]
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from . import FairseqDataset, data_utils
def collate(
samples,
pad_idx,
eos_idx,
vocab,
left_pad_source=False,
left_pad_target=False,
input_feeding=True,
pad_to_length=None,
):
assert input_feeding
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
left_pad=left_pad,
move_eos_to_beginning=move_eos_to_beginning,
pad_to_length=pad_to_length,
)
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
"source",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge(
"target",
left_pad=left_pad_target,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
target = target.index_select(0, sort_order)
ntokens = sum(len(s["target"]) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
"target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s["source"]) for s in samples)
batch = {
"id": id,
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target,
"nsentences": samples[0]["source"].size(0),
"sort_order": sort_order,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
return batch
class DenoisingDataset(FairseqDataset):
"""
A wrapper around TokenBlockDataset for BART dataset.
Args:
dataset (TokenBlockDataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary
mask_idx (int): dictionary index used for masked token
mask_whole_words: only mask whole words. This should be a byte mask
over vocab indices, indicating whether it is the beginning of a
word. We will extend any mask to encompass the whole word.
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
seed: Seed for random number generator for reproducibility.
args: argparse arguments.
"""
def __init__(
self,
dataset,
sizes,
vocab,
mask_idx,
mask_whole_words,
shuffle,
seed,
args,
eos=None,
item_transform_func=None,
):
self.dataset = dataset
self.sizes = sizes
self.vocab = vocab
self.shuffle = shuffle
self.seed = seed
self.mask_idx = mask_idx
self.mask_whole_word = mask_whole_words
self.mask_ratio = args.mask
self.random_ratio = args.mask_random
self.insert_ratio = args.insert
self.rotate_ratio = args.rotate
self.permute_sentence_ratio = args.permute_sentences
self.eos = eos if eos is not None else vocab.eos()
self.item_transform_func = item_transform_func
if args.bpe != "gpt2":
self.full_stop_index = self.vocab.eos()
else:
assert args.bpe == "gpt2"
self.full_stop_index = self.vocab.index("13")
self.replace_length = args.replace_length
if self.replace_length not in [-1, 0, 1]:
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
if args.mask_length not in ["subword", "word", "span-poisson"]:
raise ValueError(f"invalid arg: mask-length={args.mask_length}")
if args.mask_length == "subword" and args.replace_length not in [0, 1]:
raise ValueError(f"if using subwords, use replace-length=1 or 0")
self.mask_span_distribution = None
if args.mask_length == "span-poisson":
_lambda = args.poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
k_factorial *= k + 1
if ps[-1] < 0.0000001:
break
ps = torch.FloatTensor(ps)
self.mask_span_distribution = torch.distributions.Categorical(ps)
self.epoch = 0
@property
def can_reuse_epoch_itr_across_epochs(self):
return True # only the noise changes, not item sizes
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch, index):
tokens = self.dataset[index]
assert tokens[-1] == self.eos
source, target = tokens, tokens.clone()
if self.permute_sentence_ratio > 0.0:
source = self.permute_sentences(source, self.permute_sentence_ratio)
if self.mask_ratio > 0:
source = self.add_whole_word_mask(source, self.mask_ratio)
if self.insert_ratio > 0:
source = self.add_insertion_noise(source, self.insert_ratio)
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
source = self.add_rolling_noise(source)
# there can additional changes to make:
if self.item_transform_func is not None:
source, target = self.item_transform_func(source, target)
assert (source >= 0).all()
assert (source[1:-1] >= 1).all()
assert (source <= len(self.vocab)).all()
assert source[0] == self.vocab.bos()
assert source[-1] == self.eos
return {
"id": index,
"source": source,
"target": target,
}
def __len__(self):
return len(self.dataset)
def permute_sentences(self, source, p=1.0):
full_stops = source == self.full_stop_index
# Pretend it ends with a full stop so last span is a sentence
full_stops[-2] = 1
# Tokens that are full stops, where the previous token is not
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
result = source.clone()
num_sentences = sentence_ends.size(0)
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
substitutions = torch.randperm(num_sentences)[:num_to_permute]
ordering = torch.arange(0, num_sentences)
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
# Ignore <bos> at start
index = 1
for i in ordering:
sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
result[index : index + sentence.size(0)] = sentence
index += sentence.size(0)
return result
def word_starts(self, source):
if self.mask_whole_word is not None:
is_word_start = self.mask_whole_word.gather(0, source)
else:
is_word_start = torch.ones(source.size())
is_word_start[0] = 0
is_word_start[-1] = 0
return is_word_start
def add_whole_word_mask(self, source, p):
is_word_start = self.word_starts(source)
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
num_inserts = 0
if num_to_mask == 0:
return source
if self.mask_span_distribution is not None:
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
lengths = torch.cat(
[
lengths,
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
],
dim=0,
)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < num_to_mask:
i += 1
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
num_to_mask = i + 1
lengths = lengths[:num_to_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
num_inserts = num_to_mask - lengths.size(0)
num_to_mask -= num_inserts
if num_to_mask == 0:
return self.add_insertion_noise(source, num_inserts / source.size(0))
assert (lengths > 0).all()
else:
lengths = torch.ones((num_to_mask,)).long()
assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero(as_tuple=False)
indices = word_starts[
torch.randperm(word_starts.size(0))[:num_to_mask]
].squeeze(1)
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
source_length = source.size(0)
assert source_length - 1 not in indices
to_keep = torch.ones(source_length, dtype=torch.bool)
is_word_start[
-1
] = 255 # acts as a long length, so spans don't go over the end of doc
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
1, len(self.vocab), size=(mask_random.sum(),)
)
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
assert lengths.size() == indices.size()
lengths -= 1
while indices.size(0) > 0:
assert lengths.size() == indices.size()
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
1, len(self.vocab), size=(mask_random.sum(),)
)
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
uncompleted = is_word_start[indices + 1] == 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
1, len(self.vocab), size=(mask_random.sum(),)
)
assert source_length - 1 not in indices
source = source[to_keep]
if num_inserts > 0:
source = self.add_insertion_noise(source, num_inserts / source.size(0))
return source
def add_permuted_noise(self, tokens, p):
num_words = len(tokens)
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
return tokens
def add_rolling_noise(self, tokens):
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
tokens = torch.cat(
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
dim=0,
)
return tokens
def add_insertion_noise(self, tokens, p):
if p == 0.0:
return tokens
num_tokens = len(tokens)
n = int(math.ceil(num_tokens * p))
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
noise_mask[noise_indices] = 1
result = torch.LongTensor(n + len(tokens)).fill_(-1)
num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices[num_random:]] = self.mask_idx
result[noise_indices[:num_random]] = torch.randint(
low=1, high=len(self.vocab), size=(num_random,)
)
result[~noise_mask] = tokens
assert (result >= 0).all()
return result
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch of data
"""
return collate(
samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return self.sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return self.sizes[index]
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
return indices[np.argsort(self.sizes[indices], kind="mergesort")]
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
@property
def supports_prefetch(self):
return (
hasattr(self.src, "supports_prefetch")
and self.src.supports_prefetch
and hasattr(self.tgt, "supports_prefetch")
and self.tgt.supports_prefetch
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from collections import Counter
from multiprocessing import Pool
import torch
from fairseq import utils
from fairseq.binarizer import safe_readline
from fairseq.data import data_utils
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(
self,
*, # begin keyword-only arguments
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
def index(self, sym):
"""Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def string(
self,
tensor,
bpe_symbol=None,
escape_unk=False,
extra_symbols_to_ignore=None,
unk_string=None,
):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if torch.is_tensor(tensor) and tensor.dim() == 2:
return "\n".join(
self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
for t in tensor
)
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
extra_symbols_to_ignore.add(self.eos())
def token_string(i):
if i == self.unk():
if unk_string is not None:
return unk_string
else:
return self.unk_string(escape_unk)
else:
return self[i]
if hasattr(self, "bos_index"):
extra_symbols_to_ignore.add(self.bos())
sent = " ".join(
token_string(i)
for i in tensor
if utils.item(i) not in extra_symbols_to_ignore
)
return data_utils.post_process(sent, bpe_symbol)
def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
if escape:
return "<{}>".format(self.unk_word)
else:
return self.unk_word
def add_symbol(self, word, n=1, overwrite=False):
"""Adds a word to the dictionary"""
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords <= 0:
nwords = len(self)
new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[: self.nspecial]
new_count = self.count[: self.nspecial]
c = Counter(
dict(
sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
)
)
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(count)
else:
break
assert len(new_symbols) == len(new_indices)
self.count = list(new_count)
self.symbols = list(new_symbols)
self.indices = new_indices
self.pad_to_multiple_(padding_factor)
def pad_to_multiple_(self, padding_factor):
"""Pad Dictionary size to be a multiple of *padding_factor*."""
if padding_factor > 1:
i = 0
while len(self) % padding_factor != 0:
symbol = "madeupword{:04d}".format(i)
self.add_symbol(symbol, n=0)
i += 1
def bos(self):
"""Helper to get index of beginning-of-sentence symbol"""
return self.bos_index
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
def unk(self):
"""Helper to get index of unk symbol"""
return self.unk_index
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f)
return d
def add_from_file(self, f):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, str):
try:
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception(
"Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f)
)
return
lines = f.readlines()
indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]:
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
)
def _save(self, f, kv_iterator):
if isinstance(f, str):
PathManager.mkdirs(os.path.dirname(f))
with PathManager.open(f, "w", encoding="utf-8") as fd:
return self.save(fd)
for k, v in kv_iterator:
print("{} {}".format(k, v), file=f)
def _get_meta(self):
return [], []
def _load_meta(self, lines):
return 0
def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
self._save(
f,
zip(
ex_keys + self.symbols[self.nspecial :],
ex_vals + self.count[self.nspecial :],
),
)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
def encode_line(
self,
line,
line_tokenizer=tokenize_line,
add_if_not_exist=True,
consumer=None,
append_eos=True,
reverse_order=False,
):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
@staticmethod
def _add_file_to_dictionary_single_worker(
filename, tokenize, eos_word, worker_id=0, num_workers=1
):
counter = Counter()
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in sorted(counter.items()):
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(
pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers),
)
)
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(
Dictionary._add_file_to_dictionary_single_worker(
filename, tokenize, dict.eos_word
)
)
class TruncatedDictionary(object):
def __init__(self, wrapped_dict, length):
self.__class__ = type(
wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__),
{},
)
self.__dict__ = wrapped_dict.__dict__
self.wrapped_dict = wrapped_dict
self.length = min(len(self.wrapped_dict), length)
def __len__(self):
return self.length
def __getitem__(self, i):
if i < self.length:
return self.wrapped_dict[i]
return self.wrapped_dict.unk()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
"--tokenizer",
default=None,
)
build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
"--bpe",
default=None,
)
# automatically import any Python files in the encoders/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("fairseq.data.encoders." + module)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment