Commit b450a7fa authored by thomwolf's avatar thomwolf

clean up tokenization - fix python 2 tests

parent d44db114
@@ -20,14 +20,19 @@ import json
 import logging
 import os
 import regex as re
 import sys
 from io import open
-from functools import lru_cache
-from tqdm import tqdm
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache(func):
+        def func_wrapper(*inputs, **args):
+            return func(*inputs, **args)
+        return func_wrapper

 from .file_utils import cached_path
 from .tokenization import BasicTokenizer

 logger = logging.getLogger(__name__)
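For reference, this fallback keeps the module importable on Python 2, where functools.lru_cache does not exist, while Python 3 gets the real cache. A minimal standalone sketch of the same pattern (fib is a hypothetical example; the bare @lru_cache application assumes Python 3.8+, which accepts a callable directly, or the dummy above):

try:
    from functools import lru_cache
except ImportError:
    # Python 2: pass-through stand-in so the module still imports.
    def lru_cache(func):
        def func_wrapper(*inputs, **args):
            return func(*inputs, **args)
        return func_wrapper

@lru_cache
def fib(n):
    # Memoized on Python 3, recomputed on Python 2.
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(30))  # 832040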
@@ -125,7 +130,8 @@ class GPT2Tokenizer(object):
         tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
         return tokenizer

-    def __init__(self, vocab_file, merges_file, errors='replace'):
+    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
         self.errors = errors # how to handle errors in decoding
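The new max_len keyword caps the token sequence length that tokenize will accept; left unset, it defaults to an effectively unbounded int(1e12). A usage sketch, assuming local vocab/merges files (paths are hypothetical) and the 'gpt2' shortcut name:

from pytorch_pretrained_bert import GPT2Tokenizer  # import path assumed

# Direct construction from local files (hypothetical paths).
tokenizer = GPT2Tokenizer('vocab.json', 'merges.txt', max_len=1024)

# from_pretrained forwards extra keyword arguments to __init__
# (see the cls(...) call above), so the cap can be set there too.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', max_len=1024)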
@@ -188,6 +194,12 @@ class GPT2Tokenizer(object):
         for token in re.findall(self.pat, text):
             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        if len(bpe_tokens) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                "sequence length for this OpenAI GPT-2 model ({} > {}). Running this "
+                "sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
+            )
         return bpe_tokens

     def decode(self, tokens):
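With a cap in place, tokenize now fails fast instead of letting an over-long sequence trigger indexing errors inside the model. A small sketch of the new behavior, assuming a tokenizer built as above with a deliberately tiny cap:

# Assumes `tokenizer` was constructed with, e.g., max_len=8.
try:
    indices = tokenizer.tokenize("a long repeated input " * 100)
except ValueError as exc:
    print(exc)  # Token indices sequence length is longer than ...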