Commit 03c2c762 authored by thomwolf

update tokenizer

parent 3edfa1d6
@@ -16,21 +16,13 @@
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)

-import sys
 import json
 import logging
 import os
 import regex as re
 from io import open
-import pdb

-try:
-    from functools import lru_cache
-except ImportError:
-    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
-    def lru_cache():
-        return lambda func: func
+from .tokenization_bert import BasicTokenizer

 from .tokenization_utils import PreTrainedTokenizer
@@ -53,49 +45,47 @@ PRETRAINED_VOCAB_FILES_MAP = {
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'ctrl': 1280,
+    'ctrl': 256,
 }

-@lru_cache()
-def bytes_to_unicode():
+def text_standardize(text):
     """
-    Returns list of utf-8 byte and a mapping to unicode strings.
-    We specifically avoids mapping to whitespace/control characters the bpe code barfs on.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    fixes some issues the spacy tokenizer had on books corpus
+    also does some whitespace standardization
     """
-    _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8+n)
-            n += 1
-    cs = [_chr(n) for n in cs]
-    return dict(zip(bs, cs))
+    text = text.replace('—', '-')
+    text = text.replace('–', '-')
+    text = text.replace('―', '-')
+    text = text.replace('…', '...')
+    text = text.replace('´', "'")
+    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
+    text = re.sub(r'\s*\n\s*', ' \n ', text)
+    text = re.sub(r'[^\S\n]+', ' ', text)
+    return text.strip()

 def get_pairs(word):
     """Return set of symbol pairs in a word.
     Word is represented as tuple of symbols (symbols being variable-length strings).
     """
-    pairs= []
+    # pairs = []
+    # prev_char = word[0]
+    # for i, char in enumerate(word[1:]):
+    #     #_i = i + 1
+    #     #if word[_i+1:] == tuple('</w>'):
+    #     #    pairs.append((prev_char, char+'</w>'))
+    #     #    break
+    #     #else:
+    #     if True:
+    #         pairs.append((prev_char, char))
+    #     prev_char = char
+    pairs = set()
     prev_char = word[0]
-    for i, char in enumerate(word[1:]):
-        #_i = i + 1
-        #if word[_i+1:] == tuple('</w>'):
-        #    pairs.append((prev_char, char+'</w>'))
-        #    break
-        #else:
-        if True:
-            pairs.append((prev_char, char))
+    for char in word[1:]:
+        pairs.add((prev_char, char))
         prev_char = char
     pairs = set(pairs)
     return pairs
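
For reference, a standalone sketch of the two helpers this hunk introduces: text_standardize normalizes dashes, ellipses and whitespace and spaces out punctuation runs before the spacy/BPE step, and get_pairs now builds the adjacent-symbol set directly. The function bodies below are the new code from the hunk; the sample inputs and expected outputs are illustrative only, and the third-party regex package is assumed (the built-in re module would also handle these particular patterns).

import regex as re

def text_standardize(text):
    # normalize exotic dashes / ellipsis / acute accent, space out punctuation runs,
    # isolate newlines, and collapse the remaining whitespace
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()

def get_pairs(word):
    # adjacent symbol pairs of a word given as a tuple of symbols
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

print(text_standardize("state—of—the—art…  models"))
# -> 'state - of - the - art... models'
print(get_pairs(("l", "o", "w", "e", "r")))
# -> {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r')}   (a set, so order may vary)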
@@ -113,25 +103,29 @@ class CTRLTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<unk>",
-                 bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
-        super(CTRLTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+    def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
+        super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
         self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
         self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens

+        try:
+            import ftfy
+            from spacy.lang.en import English
+            _nlp = English()
+            self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
+            self.fix_text = ftfy.fix_text
+        except ImportError:
+            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
+            self.nlp = BasicTokenizer(do_lower_case=True)
+            self.fix_text = None
+
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
-        self.decoder = {v: k for k, v in self.encoder.items()}
-        self.errors = errors # how to handle errors in decoding
-        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
-        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
-        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.decoder = {v:k for k,v in self.encoder.items()}
+        merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+        merges = [tuple(merge.split()) for merge in merges]
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
-
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

     @property
     def vocab_size(self):
         return len(self.encoder)
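
The constructor change above drops the GPT-2-style byte encoder and loads the vocab and BPE merge ranks directly. A rough illustration of that loading logic with toy in-memory data (the contents below are hypothetical stand-ins for the real CTRL vocab.json / merges.txt files):

import json

# toy stand-in for vocab.json: token -> id
vocab_json = '{"Goo@@": 0, "d": 1, "Good": 2, "morning": 3, "<unk>": 4}'
encoder = json.loads(vocab_json)
decoder = {v: k for k, v in encoder.items()}       # id -> token, as in the diff

# toy stand-in for merges.txt: a header line, then one merge rule per line, best merges first
merges_txt = "# header line (skipped)\nG o\nGo o\nGoo d\n"
merges = merges_txt.split('\n')[1:-1]              # mirrors read().split('\n')[1:-1]
merges = [tuple(m.split()) for m in merges]        # ('G', 'o'), ('Go', 'o'), ('Goo', 'd')
bpe_ranks = dict(zip(merges, range(len(merges))))  # lower rank = merge applied earlier by bpe()

print(decoder[4])   # -> '<unk>'
print(bpe_ranks)    # -> {('G', 'o'): 0, ('Go', 'o'): 1, ('Goo', 'd'): 2}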
@@ -179,23 +173,27 @@ class CTRLTokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word

-    def _tokenize(self, text, add_prefix_space=False):
+    def _tokenize(self, text):
         """ Tokenize a string.
-            Args:
-                - add_prefix_space (boolean, default False):
-                    Begin the sentence with at least one space toto get invariance to word order in CTRL (and RoBERTa) tokenizers.
         """
-        if add_prefix_space:
-            text = ' ' + text
-
-        bpe_tokens = []
-        for token in text.split():
-            if sys.version_info[0] == 2:
-                token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
-            else:
-                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
-            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
-        return bpe_tokens
+        split_tokens = []
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer
+            text = self.nlp.tokenize(text)
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token).split(' ')])
+        else:
+            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
+            text = self.nlp(text_standardize(self.fix_text(text)))
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+        # for token in text.split():
+        #     if sys.version_info[0] == 2:
+        #         token = ''.join(self.byte_encoder[ord(b)] for b in token) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+        #     else:
+        #         token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+        #     bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+        return split_tokens

     def _convert_token_to_id(self, token):
         """ Converts a token (str/unicode) in an id using the vocab. """
@@ -203,13 +201,12 @@ class CTRLTokenizer(PreTrainedTokenizer):
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
-        return self.decoder.get(index)
+        return self.decoder.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens):
         """ Converts a sequence of tokens (string) in a single string. """
-        text = ''.join(tokens)
-        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
-        return text
+        out_string = ''.join(tokens).replace('@@', ' ').strip()
+        return out_string

     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
@@ -235,10 +232,8 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return vocab_file, merge_file

-    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
-        filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
-        tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
-        tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
-        return ''.join(tokens_generated_so_far)
+    # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+    #     filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
+    #     tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
+    #     tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
+    #     return ''.join(tokens_generated_so_far)
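
Taken together, a rough usage sketch of the tokenizer after this commit. The module path and file names are assumptions: the relative imports suggest this file is importable as transformers.tokenization_ctrl, and vocab.json / merges.txt are placeholders for the files shipped with the pretrained CTRL checkpoint; no particular token output is implied.

from transformers.tokenization_ctrl import CTRLTokenizer

tokenizer = CTRLTokenizer(vocab_file='vocab.json', merges_file='merges.txt')

tokens = tokenizer.tokenize("Welcome to New York")   # BasicTokenizer or spacy/ftfy pre-tokenization, then BPE
ids = tokenizer.convert_tokens_to_ids(tokens)
text = tokenizer.convert_tokens_to_string(tokens)    # ''.join(...) plus the '@@' handling from the hunk above
print(tokens, ids, text, sep='\n')

With the decode override commented out above, decoding now falls through to the base PreTrainedTokenizer implementation, which builds its output via convert_tokens_to_string.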