Commit 7143f128 authored by sunxx1

Merge branch 'hepj-test' into 'main'

Update transformer code

See merge request dcutoolkit/deeplearing/dlexamples_new!47
parents a30b77fe c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
@register_bpe("bytes")
class Bytes(object):
def __init__(self, *unused):
pass
@staticmethod
def add_args(parser):
pass
@staticmethod
def encode(x: str) -> str:
encoded = byte_encode(x)
escaped = encoded.replace(SPACE, SPACE_ESCAPE)
return SPACE.join(list(escaped))
@staticmethod
def decode(x: str) -> str:
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
@register_bpe("characters")
class Characters(object):
def __init__(self, *unused):
pass
@staticmethod
def add_args(parser):
pass
@staticmethod
def encode(x: str) -> str:
escaped = x.replace(SPACE, SPACE_ESCAPE)
return SPACE.join(list(escaped))
@staticmethod
def decode(x: str) -> str:
return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
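# Illustrative usage sketch (not part of the original file): the "characters"
# scheme is stateless, so encode/decode round-trip with no external resources.
if __name__ == "__main__":
    sample = "hello world"
    encoded = Characters.encode(sample)    # e.g. "h e l l o ▁ w o r l d"
    assert Characters.decode(encoded) == sample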
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class fastBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object):
def __init__(self, cfg):
if cfg.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
codes = file_utils.cached_path(cfg.bpe_codes)
try:
import fastBPE
self.bpe = fastBPE.fastBPE(codes)
self.bpe_symbol = "@@ "
except ImportError:
raise ImportError("Please install fastBPE with: pip install fastBPE")
def encode(self, x: str) -> str:
return self.bpe.apply([x])[0]
def decode(self, x: str) -> str:
return (x + " ").replace(self.bpe_symbol, "").rstrip()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from .gpt2_bpe_utils import get_encoder
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@dataclass
class GPT2BPEConfig(FairseqDataclass):
gpt2_encoder_json: str = field(
default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
)
gpt2_vocab_bpe: str = field(
default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
)
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
def __init__(self, cfg):
encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:
return " ".join(map(str, self.bpe.encode(x)))
def decode(self, x: str) -> str:
return self.bpe.decode(
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
)
def is_beginning_of_word(self, x: str) -> bool:
return self.decode(x).startswith(" ")
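# Illustrative usage sketch (not part of the original file): with the default config
# the encoder.json / vocab.bpe files are downloaded from the URLs above and cached.
if __name__ == "__main__":
    bpe = GPT2BPE(GPT2BPEConfig())
    ids = bpe.encode("Hello world")                  # space-separated GPT-2 token ids
    assert bpe.decode(ids) == "Hello world"
    assert bpe.is_beginning_of_word(ids.split()[1])  # " world" starts a new word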
"""
Byte pair encoding utilities from GPT-2.
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""
import json
from functools import lru_cache
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
try:
import regex as re
self.re = re
except ImportError:
raise ImportError("Please install regex with: pip install regex")
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = self.re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
                except ValueError:
                    # `first` does not occur again in the remainder of the word
                    new_word.extend(word[i:])
                    break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in self.re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder.get(token, token) for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
"utf-8", errors=self.errors
)
return text
def get_encoder(encoder_json_path, vocab_bpe_path):
with open(encoder_json_path, "r") as f:
encoder = json.load(f)
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)
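# Illustrative sketch (not part of the original file): bytes_to_unicode() is a
# reversible mapping covering all 256 byte values, and get_pairs() enumerates the
# adjacent-symbol bigrams that bpe() ranks against bpe_ranks.
if __name__ == "__main__":
    byte_encoder = bytes_to_unicode()
    assert len(byte_encoder) == 256
    assert len(set(byte_encoder.values())) == 256  # invertible, so decoding is lossless
    assert get_pairs(("h", "e", "l", "l", "o")) == {
        ("h", "e"),
        ("e", "l"),
        ("l", "l"),
        ("l", "o"),
    }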
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class BertBPEConfig(FairseqDataclass):
bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
bpe_vocab_file: Optional[str] = field(
default=None, metadata={"help": "bpe vocab file"}
)
@register_bpe("bert", dataclass=BertBPEConfig)
class BertBPE(object):
def __init__(self, cfg):
try:
from transformers import BertTokenizer
except ImportError:
raise ImportError(
"Please install transformers with: pip install transformers"
)
if cfg.bpe_vocab_file:
self.bert_tokenizer = BertTokenizer(
cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
)
else:
vocab_file_name = (
"bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
)
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
def encode(self, x: str) -> str:
return " ".join(self.bert_tokenizer.tokenize(x))
def decode(self, x: str) -> str:
return self.bert_tokenizer.clean_up_tokenization(
self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
)
def is_beginning_of_word(self, x: str) -> bool:
return not x.startswith("##")
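# Illustrative usage sketch (not part of the original file), assuming the
# `transformers` package is installed; the bert-base-uncased vocabulary is
# downloaded and cached on first use.
if __name__ == "__main__":
    bpe = BertBPE(BertBPEConfig())
    pieces = bpe.encode("Tokenization example")  # WordPiece tokens, "##" marks continuations
    print(bpe.decode(pieces))
    assert not bpe.is_beginning_of_word("##ization")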
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from fairseq import file_utils
@dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
bpe_add_prefix_space: bool = field(
default=False, metadata={"help": "add prefix space before encoding"}
)
@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object):
def __init__(self, cfg):
try:
from tokenizers import ByteLevelBPETokenizer
except ImportError:
raise ImportError(
"Please install huggingface/tokenizers with: " "pip install tokenizers"
)
bpe_vocab = file_utils.cached_path(cfg.bpe_vocab)
bpe_merges = file_utils.cached_path(cfg.bpe_merges)
self.bpe = ByteLevelBPETokenizer(
bpe_vocab,
bpe_merges,
add_prefix_space=cfg.bpe_add_prefix_space,
)
def encode(self, x: str) -> str:
return " ".join(map(str, self.bpe.encode(x).ids))
def decode(self, x: str) -> str:
return self.bpe.decode(
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
)
def is_beginning_of_word(self, x: str) -> bool:
return self.decode(x).startswith(" ")
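# Illustrative usage sketch (not part of the original file), assuming the
# `tokenizers` package is installed and that vocab.json / merges.txt exist at the
# hypothetical paths below; the cfg fields mirror HuggingFaceByteLevelBPEConfig.
if __name__ == "__main__":
    cfg = HuggingFaceByteLevelBPEConfig(
        bpe_vocab="/path/to/vocab.json",   # hypothetical path
        bpe_merges="/path/to/merges.txt",  # hypothetical path
    )
    bpe = HuggingFaceByteLevelBPE(cfg)
    ids = bpe.encode("Hello world")        # space-separated token ids
    print(bpe.decode(ids))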
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
source_lang: str = field(default="en", metadata={"help": "source language"})
target_lang: str = field(default="en", metadata={"help": "target language"})
moses_no_dash_splits: bool = field(
default=False, metadata={"help": "don't apply dash split rules"}
)
moses_no_escape: bool = field(
default=False,
metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
)
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
def __init__(self, cfg: MosesTokenizerConfig):
self.cfg = cfg
try:
from sacremoses import MosesTokenizer, MosesDetokenizer
self.tok = MosesTokenizer(cfg.source_lang)
self.detok = MosesDetokenizer(cfg.target_lang)
except ImportError:
raise ImportError(
"Please install Moses tokenizer with: pip install sacremoses"
)
def encode(self, x: str) -> str:
return self.tok.tokenize(
x,
aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
return_str=True,
escape=(not self.cfg.moses_no_escape),
)
def decode(self, x: str) -> str:
return self.detok.detokenize(x.split())
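# Illustrative usage sketch (not part of the original file), assuming sacremoses is
# installed; MosesTokenizerConfig defaults to English for both languages.
if __name__ == "__main__":
    tok = MosesTokenizer(MosesTokenizerConfig())
    encoded = tok.encode("Hello, world!")  # punctuation split off: "Hello , world !"
    print(tok.decode(encoded))             # detokenizes back to "Hello, world!"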
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@register_tokenizer("nltk", dataclass=FairseqDataclass)
class NLTKTokenizer(object):
def __init__(self, *unused):
try:
from nltk.tokenize import word_tokenize
self.word_tokenize = word_tokenize
except ImportError:
raise ImportError("Please install nltk with: pip install nltk")
def encode(self, x: str) -> str:
return " ".join(self.word_tokenize(x))
def decode(self, x: str) -> str:
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class SentencepieceConfig(FairseqDataclass):
sentencepiece_model: str = field(
default="???", metadata={"help": "path to sentencepiece model"}
)
sentencepiece_enable_sampling: bool = field(
default=False, metadata={"help": "enable sampling"}
)
sentencepiece_alpha: Optional[float] = field(
default=None,
metadata={
"help": "soothing parameter for unigram sampling, "
"and merge probability for BPE-dropout"
},
)
@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
class SentencepieceBPE(object):
def __init__(self, cfg):
self.enable_sampling = cfg.sentencepiece_enable_sampling
self.alpha = cfg.sentencepiece_alpha
sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
try:
import sentencepiece as spm
self.sp = spm.SentencePieceProcessor()
self.sp.Load(sentencepiece_model)
except ImportError:
raise ImportError(
"Please install sentencepiece with: pip install sentencepiece"
)
def encode(self, x: str) -> str:
return " ".join(
self.sp.Encode(
x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
)
)
def decode(self, x: str) -> str:
return x.replace(" ", "").replace("\u2581", " ").strip()
def is_beginning_of_word(self, x: str) -> bool:
if x in ["<unk>", "<s>", "</s>", "<pad>"]:
# special elements are always considered beginnings
# HACK: this logic is already present in fairseq/tasks/masked_lm.py
# but these special tokens are also contained in the sentencepiece
# vocabulary which causes duplicate special tokens. This hack makes
# sure that they are all taken into account.
return True
return x.startswith("\u2581")
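# Illustrative sketch (not part of the original file): decode() simply joins the
# pieces and turns the "\u2581" word-boundary marker back into spaces, so the logic
# can be checked on a hand-written piece sequence without loading a model.
if __name__ == "__main__":
    pieces = "\u2581Hello \u2581wo rld"            # hypothetical encoder output
    decoded = pieces.replace(" ", "").replace("\u2581", " ").strip()
    assert decoded == "Hello world"
    assert pieces.split()[0].startswith("\u2581")  # first piece begins a word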
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@register_tokenizer("space", dataclass=FairseqDataclass)
class SpaceTokenizer(object):
def __init__(self, *unused):
self.space_tok = re.compile(r"\s+")
def encode(self, x: str) -> str:
return self.space_tok.sub(" ", x)
def decode(self, x: str) -> str:
return x
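# Illustrative usage sketch (not part of the original file): the space tokenizer
# only collapses runs of whitespace; decode() is the identity.
if __name__ == "__main__":
    tok = SpaceTokenizer()
    assert tok.encode("a \t b\n\nc") == "a b c"
    assert tok.decode("a b c") == "a b c"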
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})
@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object):
def __init__(self, cfg):
if cfg.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
codes = file_utils.cached_path(cfg.bpe_codes)
try:
from subword_nmt import apply_bpe
bpe_parser = apply_bpe.create_parser()
bpe_args = bpe_parser.parse_args(
[
"--codes",
codes,
"--separator",
cfg.bpe_separator,
]
)
self.bpe = apply_bpe.BPE(
bpe_args.codes,
bpe_args.merges,
bpe_args.separator,
None,
bpe_args.glossaries,
)
self.bpe_symbol = bpe_args.separator + " "
except ImportError:
raise ImportError(
"Please install subword_nmt with: pip install subword-nmt"
)
def encode(self, x: str) -> str:
return self.bpe.process_line(x)
def decode(self, x: str) -> str:
return (x + " ").replace(self.bpe_symbol, "").rstrip()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.data import encoders
def get_whole_word_mask(args, dictionary):
bpe = encoders.build_bpe(args)
if bpe is not None:
def is_beginning_of_word(i):
if i < dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = dictionary[i]
if tok.startswith("madeupword"):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(
list(map(is_beginning_of_word, range(len(dictionary))))
)
return mask_whole_words
return None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch.utils.data
from fairseq.data import data_utils
logger = logging.getLogger(__name__)
class EpochListening:
"""Mixin for receiving updates whenever the epoch increments."""
@property
def can_reuse_epoch_itr_across_epochs(self):
"""
Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
this dataset across epochs.
This needs to return ``False`` if the sample sizes can change across
epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies on ``set_epoch`` then you should consider setting
this to ``False``.
"""
return True
def set_epoch(self, epoch):
"""Will receive the updated epoch number at the beginning of the epoch."""
pass
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
raise NotImplementedError
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
raise NotImplementedError
def num_tokens_vec(self, indices):
"""Return the number of tokens for a set of positions defined by indices.
This value is used to enforce ``--max-tokens`` during batching."""
raise NotImplementedError
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
raise NotImplementedError
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self), dtype=np.int64)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return False
def attr(self, attr: str, index: int):
return getattr(self, attr, None)
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
raise NotImplementedError
def get_batch_shapes(self):
"""
Return a list of valid batch shapes, for example::
[(8, 512), (16, 256), (32, 128)]
The first dimension of each tuple is the batch size and can be ``None``
to automatically infer the max batch size based on ``--max-tokens``.
The second dimension of each tuple is the max supported length as given
by :func:`fairseq.data.FairseqDataset.num_tokens`.
This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
to restrict batch shapes. This is useful on TPUs to avoid too many
dynamic shapes (and recompilations).
"""
return None
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
"""
Given an ordered set of indices, return batches according to
*max_tokens*, *max_sentences* and *required_batch_size_multiple*.
"""
from fairseq.data import data_utils
fixed_shapes = self.get_batch_shapes()
if fixed_shapes is not None:
def adjust_bsz(bsz, num_tokens):
if bsz is None:
assert max_tokens is not None, "Must specify --max-tokens"
bsz = max_tokens // num_tokens
if max_sentences is not None:
bsz = min(bsz, max_sentences)
elif (
bsz >= required_batch_size_multiple
and bsz % required_batch_size_multiple != 0
):
bsz -= bsz % required_batch_size_multiple
return bsz
fixed_shapes = np.array(
[
[adjust_bsz(bsz, num_tokens), num_tokens]
for (bsz, num_tokens) in fixed_shapes
]
)
try:
num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
except NotImplementedError:
num_tokens_vec = None
return data_utils.batch_by_size(
indices,
num_tokens_fn=self.num_tokens,
num_tokens_vec=num_tokens_vec,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
fixed_shapes=fixed_shapes,
)
def filter_indices_by_size(self, indices, max_sizes):
"""
Filter a list of sample indices. Remove those that are longer than
specified in *max_sizes*.
WARNING: don't update, override method in child classes
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
ignored = indices[self.sizes[indices] > max_sizes].tolist()
indices = indices[self.sizes[indices] <= max_sizes]
elif (
hasattr(self, "sizes")
and isinstance(self.sizes, list)
and len(self.sizes) == 1
):
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
indices = indices[self.sizes[0][indices] <= max_sizes]
else:
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
else:
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
return indices, ignored
@property
def supports_fetch_outside_dataloader(self):
"""Whether this dataset supports fetching outside the workers of the dataloader."""
return True
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
"""
For datasets that need to be read sequentially, usually because the data is
being streamed or otherwise can't be manipulated on a single machine.
"""
def __iter__(self):
raise NotImplementedError
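# Illustrative sketch (not part of the original files): a minimal concrete
# FairseqDataset over a list of 1-D LongTensors, showing the hooks the batching
# machinery above relies on (collater, num_tokens, size). The class name
# ToyTensorDataset is ours, not fairseq's.
class ToyTensorDataset(FairseqDataset):
    def __init__(self, tensors):
        self.tensors = tensors
        self.sizes = np.array([len(t) for t in tensors], dtype=np.int64)

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # naive collation: zero-pad every sample to the longest one in the batch
        max_len = max(len(s) for s in samples)
        batch = torch.zeros(len(samples), max_len, dtype=torch.long)
        for i, s in enumerate(samples):
            batch[i, : len(s)] = s
        return batch

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]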
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import subprocess
import threading
from pathlib import Path
import numpy as np
import torch
def fasta_file_path(prefix_path):
return prefix_path + ".fasta"
class FastaDataset(torch.utils.data.Dataset):
"""
For loading protein sequence datasets in the common FASTA data format
"""
def __init__(self, path: str, cache_indices=False):
self.fn = fasta_file_path(path)
self.threadlocal = threading.local()
self.cache = Path(f"{path}.fasta.idx.npy")
if cache_indices:
if self.cache.exists():
self.offsets, self.sizes = np.load(self.cache)
else:
self.offsets, self.sizes = self._build_index(path)
np.save(self.cache, np.stack([self.offsets, self.sizes]))
else:
self.offsets, self.sizes = self._build_index(path)
def _get_file(self):
if not hasattr(self.threadlocal, "f"):
self.threadlocal.f = open(self.fn, "r")
return self.threadlocal.f
def __getitem__(self, idx):
f = self._get_file()
f.seek(self.offsets[idx])
desc = f.readline().strip()
line = f.readline()
seq = ""
while line != "" and line[0] != ">":
seq += line.strip()
line = f.readline()
return desc, seq
def __len__(self):
return self.offsets.size
def _build_index(self, path: str):
# Use grep and awk to get 100M/s on local SSD.
# Should process your enormous 100G fasta in ~10 min single core...
path = fasta_file_path(path)
bytes_offsets = subprocess.check_output(
f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
"| grep --byte-offset '^>' -o | cut -d: -f1",
shell=True,
)
fasta_lengths = subprocess.check_output(
f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
"| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
shell=True,
)
bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
return bytes_np, sizes_np
def __setstate__(self, state):
self.__dict__ = state
self.threadlocal = threading.local()
def __getstate__(self):
d = {}
for i, v in self.__dict__.items():
if i != "threadlocal":
d[i] = v
return d
def __del__(self):
if hasattr(self.threadlocal, "f"):
self.threadlocal.f.close()
del self.threadlocal.f
@staticmethod
def exists(path):
return os.path.exists(fasta_file_path(path))
class EncodedFastaDataset(FastaDataset):
"""
The FastaDataset returns raw sequences - this allows us to return
indices with a dictionary instead.
"""
def __init__(self, path, dictionary):
super().__init__(path, cache_indices=True)
self.dictionary = dictionary
def __getitem__(self, idx):
desc, seq = super().__getitem__(idx)
return self.dictionary.encode_line(seq, line_tokenizer=list).long()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
from .huffman_mmap_indexed_dataset import (
HuffmanMMapIndex,
HuffmanMMapIndexedDataset,
HuffmanMMapIndexedDatasetBuilder,
vocab_file_path,
)
__all__ = [
"HuffmanCoder",
"HuffmanCodeBuilder",
"HuffmanMMapIndexedDatasetBuilder",
"HuffmanMMapIndexedDataset",
"HuffmanMMapIndex",
"vocab_file_path",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
import typing as tp
from collections import Counter, deque
from dataclasses import dataclass
from bitarray import bitarray, util
from fairseq.data import Dictionary
# basically we have to write to addressable bytes for the memory mapped
# dataset loader. Sentences that get encoded to a length that is not a
# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
BLOCKSIZE = 8
class HuffmanCoder:
def __init__(
self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
):
self.root = root
self.table = root.code_table()
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
def _pad(self, a: bitarray) -> bitarray:
"""
bitpadding, 1 then 0.
If the array is already a multiple of blocksize, we add a full block.
"""
pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
padding = bitarray("1" + "0" * pad_len)
return a + padding
def _unpad(self, a: bitarray) -> bitarray:
"""
remove the bitpadding.
There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that
"""
# count the 0 padding at the end until we find the first 1
# we want to remove the one too
remove_cnt = util.rindex(a, 1)
return a[:remove_cnt]
def encode(self, iter: tp.List[str]) -> bytes:
"""
        encode a list of tokens and return bytes. We use bitpadding to make sure the encoded bits fit in bytes.
"""
a = bitarray()
for token in iter:
code = self.get_code(token)
if code is None:
if self.unk_word is None:
raise Exception(f"unknown token {token} cannot be encoded.")
else:
token = self.unk_word
a = a + self.get_code(token)
return self._pad(a).tobytes()
def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
"""
take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id
"""
a = bitarray()
a.frombytes(bits)
return self.root.decode(self._unpad(a))
def get_code(self, symbol: str) -> tp.Optional[bitarray]:
node = self.get_node(symbol)
return None if node is None else node.code
def get_node(self, symbol: str) -> "HuffmanNode":
return self.table.get(symbol)
@classmethod
def from_file(
cls,
filename: str,
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
) -> "HuffmanCoder":
builder = HuffmanCodeBuilder.from_file(filename)
return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)
def to_file(self, filename, sep="\t"):
nodes = list(self.table.values())
nodes.sort(key=lambda n: n.id)
with open(filename, "w", encoding="utf-8") as output:
for n in nodes:
output.write(f"{n.symbol}{sep}{n.count}\n")
def __iter__(self):
for n in self.table.values():
yield n
def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
builder = HuffmanCodeBuilder()
for n in self:
builder.increment(n.symbol, n.count)
for n in other_coder:
builder.increment(n.symbol, n.count)
return builder.build_code()
def __eq__(self, other: "HuffmanCoder") -> bool:
return self.table == other.table
def __len__(self) -> int:
return len(self.table)
def __contains__(self, sym: str) -> bool:
return sym in self.table
def to_dictionary(self) -> Dictionary:
        dictionary = Dictionary(
            bos=self.bos_word, unk=self.unk_word, pad=self.pad_word, eos=self.eos_word
        )
for n in self:
dictionary.add_symbol(n.symbol, n=n.count)
dictionary.finalize()
return dictionary
@dataclass
class HuffmanNode:
"""
a node in a Huffman tree
"""
id: int
count: int
symbol: tp.Optional[str] = None
left: tp.Optional["HuffmanNode"] = None
right: tp.Optional["HuffmanNode"] = None
code: tp.Optional[bitarray] = None
def is_leaf(self) -> bool:
return self.left is None and self.right is None
def code_table(
self, prefix: tp.Optional[bitarray] = None
) -> tp.Dict[str, "HuffmanNode"]:
defaulted_prefix = prefix if prefix is not None else bitarray()
if self.is_leaf():
self.code = (
defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
) # leaf could be the root if there is only one symbol
return {self.symbol: self}
codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
return {**codes_left, **codes_right}
def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
current_node = self
for bit in bits:
if bit == 0: # go right
current_node = current_node.right
else: # go left
current_node = current_node.left
if current_node is None:
# we shouldn't be on a leaf here
raise Exception("fell off a leaf")
if current_node.is_leaf():
yield current_node
current_node = self
if current_node != self:
raise Exception("couldn't decode all the bits")
class HuffmanCodeBuilder:
"""
    build a dictionary with occurrence counts and then build the Huffman code for it.
"""
def __init__(self):
self.symbols = Counter()
def add_symbols(self, *syms) -> None:
self.symbols.update(syms)
def increment(self, symbol: str, cnt: int) -> None:
self.symbols[symbol] += cnt
@classmethod
def from_file(cls, filename):
c = cls()
with open(filename, "r", encoding="utf-8") as input:
for line in input:
split = re.split(r"[\s]+", line)
c.increment(split[0], int(split[1]))
return c
def to_file(self, filename, sep="\t"):
with open(filename, "w", encoding="utf-8") as output:
for (tok, cnt) in self.symbols.most_common():
output.write(f"{tok}{sep}{cnt}\n")
def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
if len(q1) == 0:
return q2.pop()
if len(q2) == 0:
return q1.pop()
if q1[-1].count < q2[-1].count:
return q1.pop()
return q2.pop()
def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
new_c = self.symbols + c.symbols
new_b = HuffmanCodeBuilder()
new_b.symbols = new_c
return new_b
def build_code(
self,
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
) -> HuffmanCoder:
assert len(self.symbols) > 0, "cannot build code from empty list of symbols"
if self.symbols[bos] == 0:
self.add_symbols(bos)
if self.symbols[pad] == 0:
self.add_symbols(pad)
if self.symbols[eos] == 0:
self.add_symbols(eos)
if self.symbols[unk] == 0:
self.add_symbols(unk)
node_id = 0
leaves_queue = deque(
[
HuffmanNode(symbol=symbol, count=count, id=idx)
for idx, (symbol, count) in enumerate(self.symbols.most_common())
]
) # left are the most common, right are the least common
if len(leaves_queue) == 1:
root = leaves_queue.pop()
root.id = 0
return HuffmanCoder(root)
nodes_queue = deque()
while len(leaves_queue) > 0 or len(nodes_queue) != 1:
# get the lowest two nodes at the head of each queue
node1 = self._smallest(leaves_queue, nodes_queue)
node2 = self._smallest(leaves_queue, nodes_queue)
# add new node
nodes_queue.appendleft(
HuffmanNode(
count=node1.count + node2.count, left=node1, right=node2, id=node_id
)
)
node_id += 1
# we are left with the root
return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)
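# Illustrative usage sketch (not part of the original file), assuming bitarray and
# fairseq are installed: build a coder from token counts and round-trip a sentence.
if __name__ == "__main__":
    builder = HuffmanCodeBuilder()
    builder.add_symbols("the", "the", "the", "cat", "sat", "on", "the", "mat")
    coder = builder.build_code()
    encoded = coder.encode(["the", "cat", "sat"])    # bit-padded bytes
    decoded = [node.symbol for node in coder.decode(encoded)]
    assert decoded == ["the", "cat", "sat"]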
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import mmap
import os
import shutil
import struct
import typing as tp
from functools import lru_cache
import numpy as np
import torch
from fairseq.data import indexed_dataset
from fairseq.data.huffman import HuffmanCoder
from fairseq.file_io import PathManager
class HuffmanMMapIndex:
"""
keep an index of the offsets in the huffman binary file.
First a header, then the list of sizes (num tokens) for each instance and finally
the addresses of each instance.
"""
_HDR_MAGIC = b"HUFFIDX\x00\x00"
_VERSION = 1
@classmethod
def writer(cls, path: str, data_len: int):
class _Writer:
def __enter__(self):
self._file = open(path, "wb")
# write header (magic + version)
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack("<Q", cls._VERSION))
self._file.write(struct.pack("<Q", data_len))
return self
def write(self, sizes, pointers):
# add number of items in the index to the header
self._file.write(struct.pack("<Q", len(sizes)))
# write sizes
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
# write address pointers
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path):
with open(path, "rb") as stream:
# read headers
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
(version,) = struct.unpack("<Q", stream.read(8))
assert (
self._VERSION == version
), f"Unexpected file version{version} != code version {self._VERSION}"
# read length of data file
(self._data_len,) = struct.unpack("<Q", stream.read(8))
# read number of items in data file/index
(self._len,) = struct.unpack("<Q", stream.read(8))
offset = stream.tell()
indexed_dataset._warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
def __iter__(self):
for i in range(self._len):
yield self[i]
@property
def data_len(self):
return self._data_len
@property
def sizes(self):
return self._sizes
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def vocab_file_path(prefix_path):
return prefix_path + ".vocab"
class HuffmanMMapIndexedDataset(torch.utils.data.Dataset):
"""
    an indexed dataset that uses mmap and memoryview to access data from disk
that was compressed with a HuffmanCoder.
"""
def __init__(self, prefix_path):
super().__init__()
self._prefix_path = None
self._index = None
self._bin_buffer = None
self._coder = None
self._file = None
self._bin_buffer_mmap = None
self._do_init(prefix_path)
def __getstate__(self):
return self._prefix_path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, prefix_path):
self._prefix_path = prefix_path
self._index = HuffmanMMapIndex(
indexed_dataset.index_file_path(self._prefix_path)
)
self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path))
indexed_dataset._warmup_mmap_file(
indexed_dataset.data_file_path(self._prefix_path)
)
self._file = os.open(
indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY
)
self._bin_buffer_mmap = mmap.mmap(
self._file,
self._index.data_len,
access=mmap.ACCESS_READ,
)
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
del self._bin_buffer
if self._file:
os.close(self._file)
del self._index
def __len__(self):
return len(self._index)
def _decode(self, i):
ptr, _ = self._index[i]
if i == 0:
raw_bytes = self._bin_buffer[:ptr]
else:
(prev_ptr, _) = self._index[i - 1]
raw_bytes = self._bin_buffer[prev_ptr:ptr]
return self._coder.decode(raw_bytes.tobytes())
@lru_cache(maxsize=8)
def __getitem__(self, i):
nodes = self._decode(i)
return torch.tensor([n.id for n in nodes], dtype=torch.int64)
def __iter__(self):
for idx in range(len(self)):
yield self[idx]
def get_symbols(self, i):
nodes = self._decode(i)
for n in nodes:
yield n.symbol
@property
def sizes(self):
return self._index.sizes
@property
def supports_prefetch(self):
return False
@property
def coder(self):
return self._coder
@staticmethod
def exists(prefix_path):
return (
PathManager.exists(indexed_dataset.index_file_path(prefix_path))
and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
and PathManager.exists(vocab_file_path(prefix_path))
)
class HuffmanMMapIndexedDatasetBuilder:
"""
    Helper to build a memory-mapped dataset with a Huffman coder.
You can either open/close this manually or use it as a ContextManager.
Provide your own coder, it will then be stored alongside the dataset.
The builder will first write the vocab file, then open the binary file so you can stream
into it, finally the index will be written when the builder is closed (your index should fit in memory).
"""
def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None:
self._path_prefix = path_prefix
self._coder = coder
self._sizes = []
self._ptrs = []
self._data_len = 0
def open(self):
self._coder.to_file(vocab_file_path(self._path_prefix))
self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb")
def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
self.open()
return self
def add_item(self, tokens: tp.List[str]) -> None:
"""
        add a list of tokens to the dataset; they will be compressed with the
        provided coder before being written to file.
"""
encoded = self._coder.encode(tokens)
code_len = len(encoded)
last_ptr = 0
if len(self._ptrs) > 0:
last_ptr = self._ptrs[-1]
self._sizes.append(len(tokens))
self._ptrs.append(last_ptr + code_len)
self._data_len += code_len
self._data_file.write(encoded)
def append(self, other_dataset_path_prefix: str) -> None:
"""
append an existing dataset.
Beware, if it wasn't built with the same coder, you are in trouble.
"""
other_index = HuffmanMMapIndex(
indexed_dataset.index_file_path(other_dataset_path_prefix)
)
for (ptr, size) in other_index:
self._ptrs.append(ptr + self._data_len)
self._sizes.append(size)
# Concatenate data
with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f:
shutil.copyfileobj(f, self._data_file)
self._data_len += other_index.data_len
def close(self):
self._data_file.close()
with HuffmanMMapIndex.writer(
indexed_dataset.index_file_path(self._path_prefix), self._data_len
) as index:
index.write(self._sizes, self._ptrs)
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
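# Illustrative usage sketch (not part of the original file), assuming fairseq and
# bitarray are installed: write a tiny Huffman-compressed dataset and read it back.
# The prefix path below is hypothetical.
if __name__ == "__main__":
    from fairseq.data.huffman import HuffmanCodeBuilder

    sentences = [["hello", "world"], ["hello", "huffman"]]
    code_builder = HuffmanCodeBuilder()
    for sent in sentences:
        code_builder.add_symbols(*sent)
    coder = code_builder.build_code()

    prefix = "/tmp/toy_huffman_dataset"  # hypothetical location
    with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as ds_builder:
        for sent in sentences:
            ds_builder.add_item(sent)

    dataset = HuffmanMMapIndexedDataset(prefix)
    assert list(dataset.get_symbols(0)) == ["hello", "world"]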
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class IdDataset(FairseqDataset):
def __getitem__(self, index):
return index
def __len__(self):
return 0
def collater(self, samples):
return torch.tensor(samples)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import shutil
import struct
from functools import lru_cache
import numpy as np
import torch
from fairseq.dataclass.constants import DATASET_IMPL_CHOICES
from fairseq.data.fasta_dataset import FastaDataset
from fairseq.file_io import PathManager
from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex
from . import FairseqDataset
from typing import Union
def best_fitting_int_dtype(
max_int_to_represent,
) -> Union[np.uint16, np.uint32, np.int64]:
if max_int_to_represent is None:
return np.uint32 # Safe guess
elif max_int_to_represent < 65500:
return np.uint16
elif max_int_to_represent < 4294967295:
return np.uint32
else:
return np.int64
# we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly
# https://github.com/numpy/numpy/issues/5745
def get_available_dataset_impl():
return list(map(str, DATASET_IMPL_CHOICES))
def infer_dataset_impl(path):
if IndexedRawTextDataset.exists(path):
return "raw"
elif IndexedDataset.exists(path):
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return "cached"
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return "mmap"
elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]:
return "huffman"
else:
return None
elif FastaDataset.exists(path):
return "fasta"
else:
return None
def make_builder(out_file, impl, vocab_size=None):
if impl == "mmap":
return MMapIndexedDatasetBuilder(
out_file, dtype=best_fitting_int_dtype(vocab_size)
)
elif impl == "fasta":
raise NotImplementedError
elif impl == "huffman":
raise ValueError(
"Use HuffmanCodeBuilder directly as it has a different interface."
)
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
if impl == "raw" and IndexedRawTextDataset.exists(path):
assert dictionary is not None
return IndexedRawTextDataset(path, dictionary)
elif impl == "lazy" and IndexedDataset.exists(path):
return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
elif impl == "cached" and IndexedDataset.exists(path):
return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
elif impl == "mmap" and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path)
elif impl == "fasta" and FastaDataset.exists(path):
from fairseq.data.fasta_dataset import EncodedFastaDataset
return EncodedFastaDataset(path, dictionary)
elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path):
return HuffmanMMapIndexedDataset(path)
return None
def dataset_exists(path, impl):
if impl == "raw":
return IndexedRawTextDataset.exists(path)
elif impl == "mmap":
return MMapIndexedDataset.exists(path)
elif impl == "huffman":
return HuffmanMMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
_code_to_dtype = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float64,
7: np.double,
8: np.uint16,
9: np.uint32,
10: np.uint64,
}
def _dtype_header_code(dtype) -> int:
for k in _code_to_dtype.keys():
if _code_to_dtype[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + ".bin"
class IndexedDataset(FairseqDataset):
"""Loader for TorchNet IndexedDataset"""
_HDR_MAGIC = b"TNTIDX\x00\x00"
def __init__(self, path, fix_lua_indexing=False):
super().__init__()
self.path = path
self.fix_lua_indexing = fix_lua_indexing
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
version = f.read(8)
assert struct.unpack("<Q", version) == (1,)
code, self.element_size = struct.unpack("<QQ", f.read(16))
self.dtype = _code_to_dtype[code]
self._len, self.s = struct.unpack("<QQ", f.read(16))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
def read_data(self, path):
self.data_file = open(data_file_path(path), "rb", buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError("index out of range")
def __del__(self):
if self.data_file:
self.data_file.close()
@lru_cache(maxsize=8)
def __getitem__(self, i) -> torch.Tensor:
if not self.data_file:
self.read_data(self.path)
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
item = torch.from_numpy(a).long()
if self.fix_lua_indexing:
item -= 1 # subtract 1 for 0-based indexing
return item
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return PathManager.exists(index_file_path(path)) and PathManager.exists(
data_file_path(path)
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path, fix_lua_indexing=False):
super().__init__(path, fix_lua_indexing=fix_lua_indexing)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx : ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
@lru_cache(maxsize=8)
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx : ptx + a.size])
item = torch.from_numpy(a).long()
if self.fix_lua_indexing:
item -= 1 # subtract 1 for 0-based indexing
return item
class IndexedRawTextDataset(FairseqDataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
self.size = len(self.tokens_list)
def read_data(self, path, dictionary):
with open(path, "r", encoding="utf-8") as f:
for line in f:
self.lines.append(line.strip("\n"))
tokens = dictionary.encode_line(
line,
add_if_not_exist=False,
append_eos=self.append_eos,
reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError("index out of range")
@lru_cache(maxsize=8)
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
def get_original_text(self, i):
self.check_index(i)
return self.lines[i]
def __del__(self):
pass
def __len__(self):
return self.size
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return PathManager.exists(path)
class IndexedDatasetBuilder:
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float64: 4,
np.double: 8,
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, "wb")
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
def add_item(self, tensor):
# +1 for Lua compatibility
bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
with open(data_file_path(another_file), "rb") as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, "wb")
index.write(b"TNTIDX\x00\x00")
index.write(struct.pack("<Q", 1))
index.write(
struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)
)
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
index.close()
def _warmup_mmap_file(path):
with open(path, "rb") as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index:
_HDR_MAGIC = b"MMIDIDX\x00\x00"
@classmethod
def writer(cls, path, dtype):
class _Writer:
def __enter__(self):
self._file = open(path, "wb")
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack("<Q", 1))
self._file.write(struct.pack("<B", _dtype_header_code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack("<Q", len(sizes)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path):
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = _code_to_dtype[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path):
self._path = path
self._index = self.Index(index_file_path(self._path))
_warmup_mmap_file(data_file_path(self._path))
self._bin_buffer_mmap = np.memmap(
data_file_path(self._path), mode="r", order="C"
)
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
@lru_cache(maxsize=8)
def __getitem__(self, i):
ptr, size = self._index[i]
np_array = np.frombuffer(
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
)
if self._index.dtype != np.int64:
np_array = np_array.astype(np.int64)
return torch.from_numpy(np_array)
@property
def sizes(self):
return self._index.sizes
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return PathManager.exists(index_file_path(path)) and PathManager.exists(
data_file_path(path)
)
def get_indexed_dataset_to_local(path) -> str:
local_index_path = PathManager.get_local_path(index_file_path(path))
local_data_path = PathManager.get_local_path(data_file_path(path))
assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), (
"PathManager.get_local_path does not return files with expected patterns: "
f"{local_index_path} and {local_data_path}"
)
    local_path = local_data_path[:-4]  # strip the ".bin" suffix
    assert local_path == local_index_path[:-4]  # strip the ".idx" suffix
return local_path
class MMapIndexedDatasetBuilder:
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, "wb")
self._dtype = dtype
self._sizes = []
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order="C"))
self._sizes.append(np_array.size)
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), "rb") as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes)
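# Illustrative usage sketch (not part of the original file): write a tiny "mmap"
# dataset and read it back. The output prefix below is hypothetical.
if __name__ == "__main__":
    prefix = "/tmp/toy_mmap_dataset"  # hypothetical location
    builder = MMapIndexedDatasetBuilder(data_file_path(prefix), dtype=np.int32)
    builder.add_item(torch.tensor([1, 2, 3]))
    builder.add_item(torch.tensor([4, 5]))
    builder.finalize(index_file_path(prefix))
    dataset = MMapIndexedDataset(prefix)
    assert len(dataset) == 2
    assert dataset[1].tolist() == [4, 5]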