Commit 34ccc8eb authored by lukovnikov

Merge remote-tracking branch 'upstream/master'

parents fc7693ad 68a889ee
@@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 def get_pairs(word):
     """
@@ -86,9 +87,15 @@ class OpenAIGPTTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
             merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+            special_tokens_file = None
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -117,7 +124,11 @@ class OpenAIGPTTokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
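In short, `from_pretrained` now resolves special tokens from a `special_tokens.txt` sitting next to `vocab.json` and `merges.txt`, unless the caller passes them explicitly. A minimal sketch of that precedence rule (the `/path/to/model` directory and the example tokens are illustrative, not from the commit):

```python
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

# Picks up /path/to/model/special_tokens.txt automatically if it exists.
tok_a = OpenAIGPTTokenizer.from_pretrained('/path/to/model')

# An explicit special_tokens kwarg wins; the file on disk is ignored.
tok_b = OpenAIGPTTokenizer.from_pretrained('/path/to/model', special_tokens=['<bos>', '<eos>'])
```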
@@ -139,6 +150,8 @@ class OpenAIGPTTokenizer(object):
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
+        self.special_tokens = {}
+        self.special_tokens_decoder = {}
         self.set_special_tokens(special_tokens)

     def __len__(self):
@@ -250,14 +263,51 @@ class OpenAIGPTTokenizer(object):
             tokens.append(self.decoder[i])
         return tokens

-    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
+    def encode(self, text):
+        return self.convert_tokens_to_ids(self.tokenize(text))
+
+    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
         """Converts a sequence of ids in a string."""
         tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         if clean_up_tokenization_spaces:
             out_string = out_string.replace('<unk>', '')
-            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
-                                 ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't"
-                                 ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
-                                 ).replace(" 've", "'ve")
+            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+                                 ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+                                 ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return out_string
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        index = len(self.encoder)
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+
+        return vocab_file, merge_file, special_tokens_file
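Taken together, `encode`/`decode` and `save_vocabulary` give the OpenAI GPT tokenizer a full save/reload cycle. A minimal sketch, assuming pytorch_pretrained_bert is installed (the `/tmp` path and the example special tokens are illustrative):

```python
import os
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=['<bos>', '<eos>'])
ids = tokenizer.encode("hello world")  # new helper: tokenize() + convert_tokens_to_ids()
text = tokenizer.decode(ids)           # tokenization-space cleanup is now on by default

save_dir = '/tmp/my_tokenizer'
os.makedirs(save_dir, exist_ok=True)   # save_vocabulary expects an existing directory
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(save_dir)

# from_pretrained finds vocab.json, merges.txt and special_tokens.txt in that directory.
reloaded = OpenAIGPTTokenizer.from_pretrained(save_dir)
assert reloaded.special_tokens == tokenizer.special_tokens
```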
@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            if os.path.isdir(pretrained_model_name_or_path):
+                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            else:
+                vocab_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -141,6 +144,14 @@ class TransfoXLTokenizer(object):
         else:
             raise ValueError('No <unkown> token in vocabulary')

+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a directory or file."""
+        index = 0
+        if os.path.isdir(vocab_path):
+            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        torch.save(self.__dict__, vocab_file)
+        return vocab_file
+
     def build_vocab(self):
         if self.vocab_file:
             print('building vocab from {}'.format(self.vocab_file))
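Unlike the BPE tokenizers, `TransfoXLTokenizer` serializes its entire `__dict__` with `torch.save`; note that, as committed, `vocab_file` is only assigned when `vocab_path` is a directory, so passing a plain file path to `save_vocabulary` would raise. A sketch of how the save and the new file-path branch of `from_pretrained` compose (the `/tmp` path is illustrative):

```python
import os
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

save_dir = '/tmp/txl_vocab'
os.makedirs(save_dir, exist_ok=True)  # must exist and be a directory (see note above)
vocab_file = tokenizer.save_vocabulary(save_dir)

reloaded_a = TransfoXLTokenizer.from_pretrained(save_dir)    # directory: VOCAB_NAME is joined on
reloaded_b = TransfoXLTokenizer.from_pretrained(vocab_file)  # new: a direct file path also works
```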
@@ -245,82 +256,24 @@ class TransfoXLTokenizer(object):
     def __len__(self):
         return len(self.idx2sym)

-    def _run_split_on_punc(self, text):
-        """Splits punctuation on a piece of text."""
-        if text in self.never_split:
-            return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
-        output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
-        return ["".join(x) for x in output]
-
-    def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
-        output = []
-        for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
-                continue
-            output.append(char)
-        return "".join(output)
-
-    def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if cp == 0 or cp == 0xfffd or _is_control(char):
-                continue
-            if _is_whitespace(char):
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-    def whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
-        text = text.strip()
-        if not text:
-            return []
-        if self.delimiter == '':
-            tokens = text
-        else:
-            tokens = text.split(self.delimiter)
-        return tokens
-
     def tokenize(self, line, add_eos=False, add_double_eos=False):
-        line = self._clean_text(line)
         line = line.strip()
+        # convert to lower case
+        if self.lower_case:
+            line = line.lower()

-        symbols = self.whitespace_tokenize(line)
-        split_symbols = []
-        for symbol in symbols:
-            if self.lower_case and symbol not in self.never_split:
-                symbol = symbol.lower()
-                symbol = self._run_strip_accents(symbol)
-            split_symbols.extend(self._run_split_on_punc(symbol))
+        # empty delimiter '' will evaluate False
+        if self.delimiter == '':
+            symbols = line
+        else:
+            symbols = line.split(self.delimiter)

         if add_double_eos: # lm1b
-            return ['<S>'] + split_symbols + ['<S>']
+            return ['<S>'] + symbols + ['<S>']
         elif add_eos:
-            return split_symbols + ['<eos>']
+            return symbols + ['<eos>']
         else:
-            return split_symbols
+            return symbols

 class LMOrderedIterator(object):
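The upshot of this hunk: `tokenize()` is back to plain delimiter splitting with optional lower-casing; the BERT-style punctuation splitting and accent stripping removed above are gone. A quick behavior sketch under those semantics:

```python
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer(lower_case=True)  # default delimiter splits on whitespace

print(tokenizer.tokenize(" HeLLo ! how are yoU ? "))
# ['hello', '!', 'how', 'are', 'you', '?']   -- whitespace splitting only
print(tokenizer.tokenize("Hello, world"))
# ['hello,', 'world']                        -- punctuation is no longer split off
print(tokenizer.tokenize("hello world", add_eos=True))
# ['hello', 'world', '<eos>']
```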
@@ -631,42 +584,3 @@ def get_lm_corpus(datadir, dataset):
         torch.save(corpus, fn)

     return corpus
-
-def _is_whitespace(char):
-    """Checks whether `chars` is a whitespace character."""
-    # \t, \n, and \r are technically contorl characters but we treat them
-    # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
-        return True
-    cat = unicodedata.category(char)
-    if cat == "Zs":
-        return True
-    return False
-
-def _is_control(char):
-    """Checks whether `chars` is a control character."""
-    # These are technically control characters but we count them as whitespace
-    # characters.
-    if char == "\t" or char == "\n" or char == "\r":
-        return False
-    cat = unicodedata.category(char)
-    if cat.startswith("C"):
-        return True
-    return False
-
-def _is_punctuation(char):
-    """Checks whether `chars` is a punctuation character."""
-    cp = ord(char)
-    # We treat all non-letter/number ASCII as punctuation.
-    # Characters such as "^", "$", and "`" are not in the Unicode
-    # Punctuation class but we treat them as punctuation anyways, for
-    # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-        return True
-    cat = unicodedata.category(char)
-    if cat.startswith("P"):
-        return True
-    return False
+# content of conftest.py
+
+import pytest
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--runslow", action="store_true", default=False, help="run slow tests"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--runslow"):
+        # --runslow given in cli: do not skip slow tests
+        return
+    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
+    for item in items:
+        if "slow" in item.keywords:
+            item.add_marker(skip_slow)
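This new `conftest.py` registers a `--runslow` flag and auto-skips any test marked `slow` unless the flag is given; the `@pytest.mark.slow` markers sprinkled through the test diffs below rely on it. A sketch of the pattern (the test name here is hypothetical):

```python
import pytest

@pytest.mark.slow  # skipped by default; included when invoking `pytest --runslow`
def test_download_pretrained_weights():
    ...
```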
@@ -16,15 +16,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random
+import shutil
+import pytest

 import torch

 from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                      GPT2LMHeadModel, GPT2DoubleHeadsModel)
+from pytorch_pretrained_bert.modeling_gpt2 import PRETRAINED_MODEL_ARCHIVE_MAP

 class GPT2ModelTest(unittest.TestCase):
     class GPT2ModelTester(object):
@@ -176,6 +179,22 @@ class GPT2ModelTest(unittest.TestCase):
         self.assertEqual(obj["vocab_size"], 99)
         self.assertEqual(obj["n_embd"], 37)

+    def test_config_to_json_file(self):
+        config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = GPT2Config.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_gpt2_model(*config_and_inputs)
...
@@ -16,15 +16,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random
+import shutil
+import pytest

 import torch

 from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
                                      OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
+from pytorch_pretrained_bert.modeling_openai import PRETRAINED_MODEL_ARCHIVE_MAP

 class OpenAIGPTModelTest(unittest.TestCase):
     class OpenAIGPTModelTester(object):
@@ -188,6 +191,22 @@ class OpenAIGPTModelTest(unittest.TestCase):
         self.assertEqual(obj["vocab_size"], 99)
         self.assertEqual(obj["n_embd"], 37)

+    def test_config_to_json_file(self):
+        config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = OpenAIGPTConfig.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_openai_model(*config_and_inputs)
...
@@ -16,9 +16,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random
+import shutil
+import pytest

 import torch

@@ -26,6 +29,7 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
                                      BertForTokenClassification)
+from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP

 class BertModelTest(unittest.TestCase):
@@ -251,6 +255,22 @@ class BertModelTest(unittest.TestCase):
         self.assertEqual(obj["vocab_size"], 99)
         self.assertEqual(obj["hidden_size"], 37)

+    def test_config_to_json_file(self):
+        config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = BertConfig.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
         output_result = tester.create_bert_model(*config_and_inputs)
...
@@ -16,14 +16,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import unittest
 import json
 import random
+import shutil
+import pytest

 import torch

 from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
+from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP

 class TransfoXLModelTest(unittest.TestCase):
     class TransfoXLModelTester(object):
@@ -186,6 +189,22 @@ class TransfoXLModelTest(unittest.TestCase):
         self.assertEqual(obj["n_token"], 96)
         self.assertEqual(obj["d_embed"], 37)

+    def test_config_to_json_file(self):
+        config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
+        json_file_path = "/tmp/config.json"
+        config_first.to_json_file(json_file_path)
+        config_second = TransfoXLConfig.from_json_file(json_file_path)
+        os.remove(json_file_path)
+        self.assertEqual(config_second.to_dict(), config_first.to_dict())
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)
+
     def run_tester(self, tester):
         config_and_inputs = tester.prepare_config_and_inputs()
...
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import os
+import unittest
+import json
+import shutil
+import pytest
+
+from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+
+
+class GPT2TokenizationTest(unittest.TestCase):
+
+    def test_full_tokenizer(self):
+        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
+                 "lo", "low", "er",
+                 "low", "lowest", "newer", "wider"]
+        vocab_tokens = dict(zip(vocab, range(len(vocab))))
+        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
+        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+            vocab_file = fp.name
+        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
+            fp.write("\n".join(merges))
+            merges_file = fp.name
+
+        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
+        os.remove(vocab_file)
+        os.remove(merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [13, 12, 16]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+        os.remove(special_tokens_file)
+
+        self.assertListEqual(
+            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
+             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
+            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
+             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
+
+    # @pytest.mark.slow
+    def test_tokenizer_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(tokenizer)
+
+
+if __name__ == '__main__':
+    unittest.main()
@@ -17,8 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
+import shutil
+import pytest

-from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
+from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP

 class OpenAIGPTTokenizationTest(unittest.TestCase):
@@ -32,13 +34,13 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
         with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            json.dump(vocab_tokens, fp)
+            fp.write(json.dumps(vocab_tokens))
             vocab_file = fp.name
         with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
             fp.write("\n".join(merges))
             merges_file = fp.name

-        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
+        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
         os.remove(vocab_file)
         os.remove(merges_file)
@@ -52,5 +54,26 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
+        os.remove(vocab_file)
+        os.remove(merges_file)
+        os.remove(special_tokens_file)
+
+        self.assertListEqual(
+            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
+             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
+            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
+             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
+
+    @pytest.mark.slow
+    def test_tokenizer_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+            tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(tokenizer)
+
 if __name__ == '__main__':
     unittest.main()
@@ -17,12 +17,14 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
+import shutil
+import pytest

 from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                   BertTokenizer,
                                                   WordpieceTokenizer,
                                                   _is_control, _is_punctuation,
-                                                  _is_whitespace)
+                                                  _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)

 class TokenizationTest(unittest.TestCase):
@@ -46,6 +48,24 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+
+    @pytest.mark.slow
+    def test_tokenizer_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+            tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(tokenizer)
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
...
@@ -17,10 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
+import shutil
+import pytest

-from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
-                                                             _is_control, _is_punctuation,
-                                                             _is_whitespace)
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP

 class TransfoXLTokenizationTest(unittest.TestCase):
@@ -37,54 +37,44 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.build_vocab()
         os.remove(vocab_file)

-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

+        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        tokenizer.from_pretrained(vocab_file)
+        os.remove(vocab_file)
+
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+
     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)

         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["hello", "!", "how", "are", "you", "?"])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

     def test_full_tokenizer_no_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=False)

         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])

-    def test_is_whitespace(self):
-        self.assertTrue(_is_whitespace(u" "))
-        self.assertTrue(_is_whitespace(u"\t"))
-        self.assertTrue(_is_whitespace(u"\r"))
-        self.assertTrue(_is_whitespace(u"\n"))
-        self.assertTrue(_is_whitespace(u"\u00A0"))
-
-        self.assertFalse(_is_whitespace(u"A"))
-        self.assertFalse(_is_whitespace(u"-"))
-
-    def test_is_control(self):
-        self.assertTrue(_is_control(u"\u0005"))
-
-        self.assertFalse(_is_control(u"A"))
-        self.assertFalse(_is_control(u" "))
-        self.assertFalse(_is_control(u"\t"))
-        self.assertFalse(_is_control(u"\r"))
-
-    def test_is_punctuation(self):
-        self.assertTrue(_is_punctuation(u"-"))
-        self.assertTrue(_is_punctuation(u"$"))
-        self.assertTrue(_is_punctuation(u"`"))
-        self.assertTrue(_is_punctuation(u"."))
-
-        self.assertFalse(_is_punctuation(u"A"))
-        self.assertFalse(_is_punctuation(u" "))
+    @pytest.mark.slow
+    def test_tokenizer_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
+        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
+            tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(tokenizer)

 if __name__ == '__main__':
     unittest.main()