Commit c6bea084 authored by thomwolf's avatar thomwolf
Browse files

OpenAI GPT Tokenizer can fallback on using BERT BasicTokenizer

parent e7cfc46f
......@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .tokenization import BasicTokenizer
logger = logging.getLogger(__name__)
......@@ -72,8 +73,9 @@ class OpenAIGPTTokenizer(object):
"""
BPE tokenizer. Peculiarities:
- lower case all inputs
- uses SpaCy tokenizer
- special tokens: additional symbols (ex: "__classify__") to add to a vocabulary.
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
......@@ -122,12 +124,15 @@ class OpenAIGPTTokenizer(object):
try:
import ftfy
import spacy
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.fix_text = ftfy.fix_text
except ImportError:
raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True,
never_split=special_tokens if special_tokens is not None else [])
self.fix_text = None
self.max_len = max_len if max_len is not None else int(1e12)
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.fix_text = ftfy.fix_text
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
......@@ -150,6 +155,9 @@ class OpenAIGPTTokenizer(object):
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
if self.fix_text is None:
# Using BERT's BasicTokenizer: we can update the tokenizer
self.nlp.never_split = special_tokens
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
......@@ -198,6 +206,13 @@ class OpenAIGPTTokenizer(object):
def tokenize(self, text):
""" Tokenize a string. """
split_tokens = []
if self.fix_text is None:
# Using BERT's BasicTokenizer
text = self.nlp.tokenize(text)
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text)))
for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
......@@ -219,8 +234,8 @@ class OpenAIGPTTokenizer(object):
if len(ids) > self.max_len:
raise ValueError(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment