Commit bc70779b authored by thomwolf

fixed GPT-2 tokenization on python 2

parent bdaba189
@@ -227,7 +227,7 @@ def get_from_cache(url, cache_dir=None):
         meta = {'url': url, 'etag': etag}
         meta_path = cache_path + '.json'
         with open(meta_path, 'w', encoding="utf-8") as meta_file:
-            json.dump(meta, meta_file)
+            meta_file.write(json.dumps(meta))
         logger.info("removing temp file %s", temp_file.name)
...
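For context, a minimal sketch (not the repository's code) of writing JSON through an encoding-aware text handle so the same call behaves on both Python 2 and Python 3; io.open, the example dict values, and the temp path are illustrative assumptions.

# -*- coding: utf-8 -*-
# Minimal sketch: write JSON through an encoding-aware text handle in a
# way that works on both Python 2 and Python 3.
import io
import json

meta = {'url': 'https://example.com/model.bin', 'etag': '"abc123"'}  # hypothetical values
meta_path = '/tmp/example_meta.json'                                 # hypothetical path

with io.open(meta_path, 'w', encoding='utf-8') as meta_file:
    output = json.dumps(meta)
    # On Python 2, json.dumps returns a byte string (str) by default, while
    # io.open's text handle only accepts unicode, so decode before writing.
    if isinstance(output, bytes):
        output = output.decode('utf-8')
    meta_file.write(output)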
@@ -59,6 +59,7 @@ def bytes_to_unicode():
     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
+    _chr = unichr if sys.version_info[0] == 2 else chr
     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
     cs = bs[:]
     n = 0
@@ -67,7 +68,7 @@ def bytes_to_unicode():
             bs.append(b)
             cs.append(2**8+n)
             n += 1
-    cs = [chr(n) for n in cs]
+    cs = [_chr(n) for n in cs]
     return dict(zip(bs, cs))
 
 def get_pairs(word):
@@ -219,7 +220,7 @@ class GPT2Tokenizer(object):
         """ Tokenize a string. """
         bpe_tokens = []
         for token in re.findall(self.pat, text):
-            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            token = ''.join(self.byte_encoder[ord(b)] for b in token.encode('utf-8'))
             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
         return bpe_tokens
...
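The two hunks above hinge on Python 2/3 differences in the byte-level step: Python 2's chr() only covers code points 0-255 and returns byte strings (hence unichr), and iterating over UTF-8-encoded text yields 1-character strings on Python 2 but ints on Python 3 (hence ord(b) when indexing byte_encoder). Below is a minimal, self-contained sketch of that step, assuming the same bytes_to_unicode mapping as the tokenizer; the toy token and the explicit py2/py3 branch are illustrative, not the repository's exact code.

# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import sys

# Pick the right "code point -> character" function per interpreter:
# Python 2's chr() only covers 0..255, so unichr() is needed there.
_chr = unichr if sys.version_info[0] == 2 else chr  # noqa: F821 (unichr is py2-only)

def bytes_to_unicode():
    """Map every byte value (0..255) to a printable unicode character."""
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("¡"), ord("¬") + 1)) +
          list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))

byte_encoder = bytes_to_unicode()
token = "héllo"                      # toy input
encoded = token.encode("utf-8")

# Iterating UTF-8 bytes differs between interpreters: Python 3 yields ints,
# Python 2 yields 1-character strings that need ord() before indexing.
if sys.version_info[0] == 2:
    mapped = "".join(byte_encoder[ord(b)] for b in encoded)
else:
    mapped = "".join(byte_encoder[b] for b in encoded)
print(repr(mapped))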
@@ -31,13 +31,14 @@ class GPT2TokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
         with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            json.dump(vocab_tokens, fp)
+            fp.write(json.dumps(vocab_tokens))
             vocab_file = fp.name
         with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
             fp.write("\n".join(merges))
             merges_file = fp.name
 
         tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
+        print("encoder", tokenizer.byte_encoder)
         os.remove(vocab_file)
         os.remove(merges_file)
...
@@ -32,7 +32,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
         with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            json.dump(vocab_tokens, fp)
+            fp.write(json.dumps(vocab_tokens))
             vocab_file = fp.name
         with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
             fp.write("\n".join(merges))
...