Commit b450a7fa authored by thomwolf's avatar thomwolf
Browse files

Clean up tokenization; fix Python 2 tests.

parent d44db114
...@@ -20,14 +20,19 @@ import json ...@@ -20,14 +20,19 @@ import json
import logging
import os
import regex as re
import sys
from io import open
from functools import lru_cache
try:
    from functools import lru_cache
except ImportError:
    # Python 2 fallback: functools.lru_cache does not exist there.
    # Just a dummy decorator so the module imports and the checks run on
    # Python 2 — the byte-level unicode BPE tokenizer is deliberately not
    # supported on Python 2, so no actual caching is performed.
    def lru_cache(func):
        """No-op stand-in for functools.lru_cache: calls ``func`` directly."""
        def func_wrapper(*inputs, **kwargs):
            # Forward the arguments unpacked. The previous version called
            # ``func(inputs, kwargs)``, which passed the tuple and dict as
            # two positional arguments and broke every decorated call.
            return func(*inputs, **kwargs)
        return func_wrapper
from .file_utils import cached_path
from .tokenization import BasicTokenizer
logger = logging.getLogger(__name__)
...@@ -125,7 +130,8 @@ class GPT2Tokenizer(object): ...@@ -125,7 +130,8 @@ class GPT2Tokenizer(object):
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
return tokenizer return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace'): def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
...@@ -188,6 +194,12 @@ class GPT2Tokenizer(object): ...@@ -188,6 +194,12 @@ class GPT2Tokenizer(object):
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
if len(bpe_tokens) > self.max_len:
raise ValueError(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
)
return bpe_tokens return bpe_tokens
def decode(self, tokens): def decode(self, tokens):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment