Commit c6bea084 authored by thomwolf

OpenAI GPT Tokenizer can fall back to using BERT's BasicTokenizer

parent e7cfc46f
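Below is a minimal usage sketch of the behaviour this commit introduces (not part of the diff; it assumes the pytorch_pretrained_bert package exports OpenAIGPTTokenizer at the top level and that the 'openai-gpt' shortcut resolves to the pretrained vocabulary and merges files). When ftfy and SpaCy are importable the original OpenAI pre-BPE pipeline is used; otherwise the tokenizer falls back to BERT's BasicTokenizer and only logs a warning instead of raising ImportError.

# Hedged sketch, not part of the commit.
from pytorch_pretrained_bert import OpenAIGPTTokenizer  # assumed top-level export

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

# fix_text is only set when ftfy and SpaCy imported successfully; None means
# the BERT BasicTokenizer fallback added by this commit is in use.
if tokenizer.fix_text is None:
    print("pre-BPE tokenization: BERT BasicTokenizer (fallback)")
else:
    print("pre-BPE tokenization: SpaCy + ftfy (original OpenAI pipeline)")

tokens = tokenizer.tokenize("Jim Henson was a puppeteer")
ids = tokenizer.convert_tokens_to_ids(tokens)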
@@ -26,6 +26,7 @@ from io import open
 from tqdm import tqdm
 from .file_utils import cached_path
+from .tokenization import BasicTokenizer
 logger = logging.getLogger(__name__)
@@ -72,8 +73,9 @@ class OpenAIGPTTokenizer(object):
     """
     BPE tokenizer. Peculiarities:
         - lower case all inputs
-        - uses SpaCy tokenizer
-        - special tokens: additional symbols (ex: "__classify__") to add to a vocabulary.
+        - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
+        - argument special_tokens and function set_special_tokens:
+            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
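The reworded docstring points at the special_tokens argument and the set_special_tokens method. A short hedged sketch (continuing the example above, with illustrative token names) of how extra control symbols get ids appended after the BPE vocabulary:

# Hedged sketch: register extra control symbols on top of the BPE vocabulary.
tokenizer.set_special_tokens(["__classify__", "__separate__"])

# Special tokens are indexed after the regular entries, starting at
# len(tokenizer.encoder); when the BasicTokenizer fallback is active,
# never_split is refreshed so these symbols are not split on punctuation.
print(tokenizer.special_tokens)           # token -> id
print(tokenizer.special_tokens_decoder)   # id -> token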
@@ -122,12 +124,15 @@ class OpenAIGPTTokenizer(object):
         try:
             import ftfy
             import spacy
+            self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
+            self.fix_text = ftfy.fix_text
         except ImportError:
-            raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
+            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
+            self.nlp = BasicTokenizer(do_lower_case=True,
+                                      never_split=special_tokens if special_tokens is not None else [])
+            self.fix_text = None
         self.max_len = max_len if max_len is not None else int(1e12)
-        self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
-        self.fix_text = ftfy.fix_text
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
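The constructor change is the usual optional-dependency pattern: try the richer backend and degrade gracefully with a logged warning. A self-contained sketch of the pattern outside the library (the fallback class here is purely illustrative, and it assumes the SpaCy 'en' model is available whenever spacy itself is installed):

import logging

logger = logging.getLogger(__name__)


class WhitespaceFallbackTokenizer(object):
    """Illustrative stand-in for a basic tokenizer: whitespace splitting only."""
    def tokenize(self, text):
        return text.strip().split()


try:
    import ftfy
    import spacy
    nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
    fix_text = ftfy.fix_text
except ImportError:
    logger.warning("ftfy or spacy is not installed, using a basic fallback tokenizer.")
    nlp = WhitespaceFallbackTokenizer()
    fix_text = None  # sentinel that later code dispatches on, as in the diff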
@@ -150,6 +155,9 @@ class OpenAIGPTTokenizer(object):
             return
         self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer: we can update the tokenizer
+            self.nlp.never_split = special_tokens
         logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):
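Refreshing never_split matters because BasicTokenizer lower-cases tokens and splits them on punctuation, which would shred a symbol such as "__classify__". A small hedged sketch against the BasicTokenizer imported at the top of this diff (module path assumed to be pytorch_pretrained_bert.tokenization):

# Hedged sketch: effect of never_split on the fallback tokenizer.
from pytorch_pretrained_bert.tokenization import BasicTokenizer

plain = BasicTokenizer(do_lower_case=True, never_split=[])
aware = BasicTokenizer(do_lower_case=True, never_split=["__classify__"])

print(plain.tokenize("__classify__ Henson"))  # '__classify__' is split on the underscores
print(aware.tokenize("__classify__ Henson"))  # '__classify__' is kept as a single token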
@@ -198,9 +206,16 @@ class OpenAIGPTTokenizer(object):
     def tokenize(self, text):
         """ Tokenize a string. """
         split_tokens = []
-        text = self.nlp(text_standardize(self.fix_text(text)))
-        for token in text:
-            split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer
+            text = self.nlp.tokenize(text)
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token).split(' ')])
+        else:
+            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
+            text = self.nlp(text_standardize(self.fix_text(text)))
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
         return split_tokens

     def convert_tokens_to_ids(self, tokens):
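Both branches end in the same bpe() call, so the caller's view does not depend on which pre-BPE front end ran. Continuing the hedged sketch from above:

# Hedged sketch: downstream code is unaffected by the pre-BPE front end.
tokens = tokenizer.tokenize("What a lovely day")
# tokens are lower-cased BPE pieces; word-final pieces typically end in '</w>'
ids = tokenizer.convert_tokens_to_ids(tokens)
assert len(ids) == len(tokens)  # one id per sub-token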
@@ -219,8 +234,8 @@ class OpenAIGPTTokenizer(object):
         if len(ids) > self.max_len:
             raise ValueError(
                 "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+                " sequence length for this OpenAI GPT model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
             )
         return ids