Commit 7df61696 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fairseq0.10.2

parents
Pipeline #471 failed with stages
in 0 seconds
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
@register_bpe("byte_bpe")
class ByteBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--sentencepiece-model-path', type=str,
help='path to sentencepiece model')
# fmt: on
def __init__(self, args):
vocab = file_utils.cached_path(args.sentencepiece_model_path)
try:
import sentencepiece as spm
self.sp = spm.SentencePieceProcessor()
self.sp.Load(vocab)
except ImportError:
raise ImportError(
"Please install sentencepiece with: pip install sentencepiece"
)
def encode(self, x: str) -> str:
byte_encoded = byte_encode(x)
return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))
@staticmethod
def decode(x: str) -> str:
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
PRINTABLE_LATIN = set(
list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
BYTE_TO_BCHAR = {
b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
def byte_encode(x: str) -> str:
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
def byte_decode(x: str) -> str:
try:
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
except ValueError:
return ""
def smart_byte_decode(x: str) -> str:
output = byte_decode(x)
if output == "":
# DP the best recovery (max valid chars) if it's broken
n_bytes = len(x)
f = [0 for _ in range(n_bytes + 1)]
pt = [0 for _ in range(n_bytes + 1)]
for i in range(1, n_bytes + 1):
f[i], pt[i] = f[i - 1], i - 1
for j in range(1, min(4, i) + 1):
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
f[i], pt[i] = f[i - j] + 1, i - j
cur_pt = n_bytes
while cur_pt > 0:
if f[cur_pt] == f[pt[cur_pt]] + 1:
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
cur_pt = pt[cur_pt]
return output
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
@register_bpe("bytes")
class Bytes(object):
def __init__(self, args):
pass
@staticmethod
def add_args(parser):
pass
@staticmethod
def encode(x: str) -> str:
encoded = byte_encode(x)
escaped = encoded.replace(SPACE, SPACE_ESCAPE)
return SPACE.join(list(escaped))
@staticmethod
def decode(x: str) -> str:
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
@register_bpe("characters")
class Characters(object):
def __init__(self, args):
pass
@staticmethod
def add_args(parser):
pass
@staticmethod
def encode(x: str) -> str:
escaped = x.replace(SPACE, SPACE_ESCAPE)
return SPACE.join(list(escaped))
@staticmethod
def decode(x: str) -> str:
return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
@register_bpe("fastbpe")
class fastBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-codes', type=str,
help='path to fastBPE BPE')
# fmt: on
def __init__(self, args):
if args.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
codes = file_utils.cached_path(args.bpe_codes)
try:
import fastBPE
self.bpe = fastBPE.fastBPE(codes)
self.bpe_symbol = "@@ "
except ImportError:
raise ImportError("Please install fastBPE with: pip install fastBPE")
def encode(self, x: str) -> str:
return self.bpe.apply([x])[0]
def decode(self, x: str) -> str:
return (x + " ").replace(self.bpe_symbol, "").rstrip()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from .gpt2_bpe_utils import get_encoder
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@register_bpe("gpt2")
class GPT2BPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--gpt2-encoder-json', type=str,
default=DEFAULT_ENCODER_JSON,
help='path to encoder.json')
parser.add_argument('--gpt2-vocab-bpe', type=str,
default=DEFAULT_VOCAB_BPE,
help='path to vocab.bpe')
# fmt: on
def __init__(self, args):
encoder_json = file_utils.cached_path(
getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON)
)
vocab_bpe = file_utils.cached_path(
getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE)
)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:
return " ".join(map(str, self.bpe.encode(x)))
def decode(self, x: str) -> str:
return self.bpe.decode(
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
)
def is_beginning_of_word(self, x: str) -> bool:
return self.decode(x).startswith(" ")
"""
Byte pair encoding utilities from GPT-2.
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""
import json
from functools import lru_cache
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
try:
import regex as re
self.re = re
except ImportError:
raise ImportError("Please install regex with: pip install regex")
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = self.re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in self.re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder.get(token, token) for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
"utf-8", errors=self.errors
)
return text
def get_encoder(encoder_json_path, vocab_bpe_path):
with open(encoder_json_path, "r") as f:
encoder = json.load(f)
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
@register_bpe("bert")
class BertBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-cased', action='store_true',
help='set for cased BPE',
default=False)
parser.add_argument('--bpe-vocab-file', type=str,
help='bpe vocab file.')
# fmt: on
def __init__(self, args):
try:
from transformers import BertTokenizer
except ImportError:
raise ImportError(
"Please install transformers with: pip install transformers"
)
if "bpe_vocab_file" in args:
self.bert_tokenizer = BertTokenizer(
args.bpe_vocab_file, do_lower_case=not args.bpe_cased
)
else:
vocab_file_name = (
"bert-base-cased" if args.bpe_cased else "bert-base-uncased"
)
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
def encode(self, x: str) -> str:
return " ".join(self.bert_tokenizer.tokenize(x))
def decode(self, x: str) -> str:
return self.bert_tokenizer.clean_up_tokenization(
self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
)
def is_beginning_of_word(self, x: str) -> bool:
return not x.startswith("##")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment