Commit e468192e authored by thomwolf

Merge branch 'pytorch-transformers' into xlnet

parents 9dd2c860 4ce237c8
......@@ -25,10 +25,10 @@ import pytest
import torch
from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
from pytorch_transformers import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
class GPT2ModelTest(unittest.TestCase):
......
......@@ -21,10 +21,10 @@ import pytest
import torch
from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
class OpenAIModelTest(unittest.TestCase):
......
......@@ -412,8 +412,8 @@ class GPTModelTester(object):
[[], []])
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.parent.assertIsNotNone(model)
......
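As a standalone sketch of the slow-test pattern above, assuming the lowercase pretrained_model_archive_map class attribute this branch exposes (the checkpoint is downloaded on first use):

import shutil
from pytorch_transformers import GPT2Model

cache_dir = "/tmp/pytorch_transformers_test/"
# Take the first checkpoint name listed in the class-level archive map.
model_name = list(GPT2Model.pretrained_model_archive_map.keys())[0]
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
assert model is not None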
......@@ -25,10 +25,10 @@ import pytest
import torch
from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class TransfoXLModelTest(unittest.TestCase):
class TransfoXLModelTester(object):
......@@ -184,8 +184,8 @@ class TransfoXLModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -16,29 +16,25 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
from pytorch_pretrained_bert.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
import logging
from pytorch_transformers import PretrainedConfig, PreTrainedModel
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
class ModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained(self):
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
config = BertConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, PretrainedConfig)
model = BertModel.from_pretrained(model_name)
model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
self.assertIsNotNone(model)
self.assertIsInstance(model, PreTrainedModel)
for value in loading_info.values():
self.assertEqual(len(value), 0)
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
......
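A quick illustration of the new output_loading_info flag checked above (a sketch; the model name is one of the archive-map entries and is downloaded on first use):

import logging
from pytorch_transformers.modeling_bert import BertModel

logging.basicConfig(level=logging.INFO)

model, loading_info = BertModel.from_pretrained('bert-base-uncased', output_loading_info=True)
# For a matching checkpoint/architecture pair each category of loading info
# (e.g. missing or unexpected keys) should come back empty.
for value in loading_info.values():
    assert len(value) == 0

# Configuration flags can be forwarded through from_pretrained the same way the test does.
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True, output_hidden_states=True)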
......@@ -20,10 +20,10 @@ import unittest
import shutil
import pytest
from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_pretrained_bert.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
class XLMModelTest(unittest.TestCase):
......@@ -250,8 +250,8 @@ class XLMModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -25,10 +25,10 @@ import pytest
import torch
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class XLNetModelTest(unittest.TestCase):
class XLNetModelTester(object):
......@@ -278,8 +278,8 @@ class XLNetModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
......@@ -20,9 +20,9 @@ import unittest
import torch
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert import OpenAIAdam
from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
from pytorch_transformers import BertAdam
from pytorch_transformers import OpenAIAdam
from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
import numpy as np
......
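A minimal sketch of how the optimizer imported here is typically driven; the keyword names follow the historical BertAdam signature and may have shifted slightly in this refactor:

import torch
from pytorch_transformers import BertAdam

model = torch.nn.Linear(10, 2)
# warmup is the fraction of t_total over which the learning rate ramps up linearly.
optimizer = BertAdam(model.parameters(), lr=1e-3, warmup=0.1, t_total=100)

for _ in range(3):
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()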
......@@ -17,29 +17,28 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from io import open
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_bert import (BasicTokenizer,
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
BertTokenizer,
WordpieceTokenizer,
_is_control, _is_punctuation,
_is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
_is_whitespace, VOCAB_FILES_NAMES)
from .tokenization_tests_commons import create_and_check_tokenizer_commons
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
class TokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
"##ing", ",", "low", "lowest",
]
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
with TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
tokenizer = BertTokenizer(vocab_file)
......@@ -47,16 +46,6 @@ class TokenizationTest(unittest.TestCase):
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
os.remove(vocab_file)
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
def test_chinese(self):
tokenizer = BasicTokenizer()
......@@ -88,7 +77,7 @@ class TokenizationTest(unittest.TestCase):
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab)
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
self.assertListEqual(tokenizer.tokenize(""), [])
......
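For context, the unk_token argument now required by WordpieceTokenizer (visible in the test change above) can be exercised in isolation; a small sketch with an illustrative vocabulary:

from pytorch_transformers.tokenization_bert import WordpieceTokenizer

vocab_tokens = ["[UNK]", "want", "##want", "##ed", "un", "runn", "##ing"]
vocab = {token: i for i, token in enumerate(vocab_tokens)}
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

print(tokenizer.tokenize("unwanted running"))
# expected: ['un', '##want', '##ed', 'runn', '##ing']
print(tokenizer.tokenize("unwantedX"))
# a word with no full wordpiece decomposition falls back to the unk token: ['[UNK]']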
......@@ -17,12 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import create_and_check_tokenizer_commons
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
class GPT2TokenizationTest(unittest.TestCase):
......@@ -30,39 +28,32 @@ class GPT2TokenizationTest(unittest.TestCase):
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"lo", "low", "er",
"low", "lowest", "newer", "wider"]
"low", "lowest", "newer", "wider", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
special_tokens_map = {"unk_token": "<unk>"}
with TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [13, 12, 16]
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [13, 12, 17]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
os.remove(vocab_file)
os.remove(merges_file)
# @pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
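To see the fixture above end to end, a minimal self-contained sketch (toy vocabulary and merges mirroring the test; only tokenization is checked):

import json
import os
import tempfile
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES

# Toy byte-level BPE vocabulary and merge rules.
vocab_tokens = {tok: i for i, tok in enumerate(
    ["l", "o", "w", "e", "r", "lo", "low", "er", "lowest", "<unk>"])}
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]

tmpdirname = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
    fp.write(json.dumps(vocab_tokens))
with open(merges_file, "w") as fp:
    fp.write("\n".join(merges))

tokenizer = GPT2Tokenizer(vocab_file, merges_file, unk_token="<unk>")
print(tokenizer.tokenize("lower"))   # expected: ['low', 'er']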
......@@ -17,12 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
from.tokenization_tests_commons import create_and_check_tokenizer_commons
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
class OpenAIGPTTokenizationTest(unittest.TestCase):
......@@ -32,21 +30,21 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>",
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
with TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
text = "lower"
bpe_tokens = ["low", "er</w>"]
......@@ -58,14 +56,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
......@@ -12,56 +12,109 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import sys
from io import open
if sys.version_info[0] == 3:
unicode = str
import tempfile
import shutil
if sys.version_info[0] == 2:
import cPickle as pickle
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
import pickle
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
vocab_path="/tmp/"
output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
tokenizer = tokenizer.from_pretrained(vocab_path)
for f in output_files:
os.remove(f)
with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
tokenizer = tokenizer.from_pretrained(tmpdirname)
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
tester.assertListEqual(before_tokens, after_tokens)
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
text = "Munich and Berlin are nice cities"
filename = u"/tmp/tokenizer.bin"
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
tester.assertIsNotNone(tokenizer)
text = u"Munich and Berlin are nice cities"
subwords = tokenizer.tokenize(text)
with TemporaryDirectory() as tmpdirname:
filename = os.path.join(tmpdirname, u"tokenizer.bin")
pickle.dump(tokenizer, open(filename, "wb"))
tokenizer_new = pickle.load(open(filename, "rb"))
subwords_loaded = tokenizer_new.tokenize(text)
tester.assertListEqual(subwords, subwords_loaded)
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
vocab_size = tokenizer.vocab_size
all_size = len(tokenizer)
tester.assertNotEqual(vocab_size, 0)
tester.assertEqual(vocab_size, all_size)
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)
vocab_size_2 = tokenizer.vocab_size
all_size_2 = len(tokenizer)
tester.assertNotEqual(vocab_size_2, 0)
tester.assertEqual(vocab_size, vocab_size_2)
tester.assertEqual(added_toks, len(new_toks))
tester.assertEqual(all_size_2, all_size + len(new_toks))
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
tester.assertGreaterEqual(len(tokens), 4)
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
'pad_token': "<<<<<|||>|>>>>|>"}
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
vocab_size_3 = tokenizer.vocab_size
all_size_3 = len(tokenizer)
tester.assertNotEqual(vocab_size_3, 0)
tester.assertEqual(vocab_size, vocab_size_3)
tester.assertEqual(added_toks_2, len(new_toks_2))
tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
tester.assertGreaterEqual(len(tokens), 6)
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
tester.assertGreater(tokens[0], tokens[1])
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
tester.assertGreater(tokens[-2], tokens[-3])
tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
text = u"He is very happy, UNwant\u00E9d,running"
tokens = tokenizer.tokenize(text)
......@@ -77,5 +130,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
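The new add_tokens / add_special_tokens API exercised above, as a sketch against a downloaded vocabulary (the model name and token strings are illustrative):

from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
base_len = len(tokenizer)

# Plain added tokens grow len(tokenizer) while vocab_size (the original table) stays fixed.
num_added = tokenizer.add_tokens(["aaaaabbbbbb", "cccccccccdddddddd"])
assert num_added == 2
assert len(tokenizer) == base_len + 2

# Named special tokens are added to the vocabulary and exposed as attributes.
tokenizer.add_special_tokens({'pad_token': '<PAD>'})
assert tokenizer.pad_token == '<PAD>'
assert tokenizer.convert_tokens_to_ids('<PAD>') >= tokenizer.vocab_size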
......@@ -17,27 +17,26 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from io import open
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
from.tokenization_tests_commons import create_and_check_tokenizer_commons
from.tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
class TransfoXLTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
vocab_tokens = [
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
"running", ",", "low", "l",
]
with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
with TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
......@@ -59,13 +58,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import six
from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
class TokenizerUtilsTest(unittest.TestCase):
def check_tokenizer_from_pretrained(self, tokenizer_class):
s3_models = list(tokenizer_class.max_model_input_sizes.keys())
for model_name in s3_models[:1]:
tokenizer = tokenizer_class.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, tokenizer_class)
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
for special_tok in tokenizer.all_special_tokens:
if six.PY2:
self.assertIsInstance(special_tok, unicode)
else:
self.assertIsInstance(special_tok, str)
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
self.assertIsInstance(special_tok_id, int)
def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
if __name__ == "__main__":
unittest.main()
......@@ -17,12 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from.tokenization_tests_commons import create_and_check_tokenizer_commons
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
class XLMTokenizationTest(unittest.TestCase):
......@@ -31,21 +29,21 @@ class XLMTokenizationTest(unittest.TestCase):
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>",
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
with TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
tokenizer = XLMTokenizer(vocab_file, merges_file)
text = "lower"
bpe_tokens = ["low", "er</w>"]
......@@ -57,14 +55,6 @@ class XLMTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = XLMTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
......@@ -16,14 +16,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import shutil
import pytest
from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
PRETRAINED_VOCAB_ARCHIVE_MAP,
SPIECE_UNDERLINE)
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from.tokenization_tests_commons import create_and_check_tokenizer_commons
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fixtures/test_sentencepiece.model')
......@@ -31,10 +27,13 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
class XLNetTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
tokens = tokenizer.tokenize(u'This is a test')
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
......@@ -60,14 +59,6 @@ class XLNetTokenizationTest(unittest.TestCase):
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
......
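The save/reload round trip these tests now share, as a standalone sketch (model name illustrative; save_pretrained is assumed to write the vocabulary plus the special-tokens bookkeeping files):

import tempfile
from pytorch_transformers.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
before = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

tmpdirname = tempfile.mkdtemp()
tokenizer.save_pretrained(tmpdirname)
reloaded = BertTokenizer.from_pretrained(tmpdirname)
after = reloaded.encode(u"He is very happy, UNwant\u00E9d,running")
assert before == after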
......@@ -22,12 +22,15 @@ import os
import unicodedata
from io import open
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
......@@ -41,8 +44,10 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
}
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
......@@ -57,7 +62,6 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-large-cased-whole-word-masking-finetuned-squad': 512,
'bert-base-cased-finetuned-mrpc': 512,
}
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
......@@ -83,7 +87,7 @@ def whitespace_tokenize(text):
return tokens
class BertTokenizer(object):
class BertTokenizer(PreTrainedTokenizer):
r"""
Constructs a BertTokenizer.
:class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
......@@ -98,8 +102,26 @@ class BertTokenizer(object):
do_wordpiece_only=False
"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
mask_token="[MASK]", **kwargs):
"""Constructs a BertTokenizer.
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
"""
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
......@@ -111,97 +133,41 @@ class BertTokenizer(object):
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
@property
def UNK_TOKEN(self):
return "[UNK]"
@property
def SEP_TOKEN(self):
return "[SEP]"
@property
def PAD_TOKEN(self):
return "[PAD]"
@property
def CLS_TOKEN(self):
return "[CLS]"
@property
def MASK_TOKEN(self):
return "[MASK]"
@property
def UNK_ID(self):
return self.vocab["[UNK]"]
@property
def SEP_ID(self):
return self.vocab["[SEP]"]
@property
def PAD_ID(self):
return self.vocab["[PAD]"]
@property
def CLS_ID(self):
return self.vocab["[CLS]"]
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
@property
def MASK_ID(self):
return self.vocab["[MASK]"]
def vocab_size(self):
return len(self.vocab)
def tokenize(self, text):
def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text):
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.vocab.get(token, self.vocab.get(self.unk_token))
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
def decode(self, token_ids, clean_up_tokenization_spaces=True):
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(token_ids)
tokens = self.convert_ids_to_tokens(tokens_ids)
out_string = ''.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN):
out_string = out_string.replace(special_tok, '')
out_string = clean_up_tokenization(out_string)
return out_string
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
......@@ -213,13 +179,10 @@ class BertTokenizer(object):
return (vocab_file,)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" Instantiate a BertTokenizer from pre-trained vocabulary files.
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
......@@ -230,40 +193,8 @@ class BertTokenizer(object):
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True
else:
vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
class BasicTokenizer(object):
......@@ -271,17 +202,20 @@ class BasicTokenizer(object):
def __init__(self,
do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
never_split=None):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text):
def tokenize(self, text, never_split=None):
"""Tokenizes a piece of text."""
never_split = self.never_split + (never_split if never_split is not None else [])
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
......@@ -293,7 +227,7 @@ class BasicTokenizer(object):
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
if self.do_lower_case and token not in never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
......@@ -312,9 +246,9 @@ class BasicTokenizer(object):
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
......@@ -386,7 +320,7 @@ class BasicTokenizer(object):
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
......
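To show how the refactored BertTokenizer is driven once the shared machinery lives in PreTrainedTokenizer (a sketch; model name illustrative):

from pytorch_transformers.tokenization_bert import BertTokenizer

# from_pretrained now delegates to PreTrainedTokenizer._from_pretrained.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tokens = tokenizer.tokenize(u"UNwanted, running")   # basic tokenization + wordpiece via _tokenize
ids = tokenizer.convert_tokens_to_ids(tokens)
assert tokenizer.convert_ids_to_tokens(ids) == tokens
# Special tokens are ordinary constructor kwargs handled by the base class.
assert tokenizer.unk_token == "[UNK]"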
......@@ -23,8 +23,6 @@ import os
import regex as re
from io import open
from .model_utils import clean_up_tokenization
try:
from functools import lru_cache
except ImportError:
......@@ -33,24 +31,32 @@ except ImportError:
def lru_cache():
return lambda func: func
from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
},
'merges_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
},
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024,
'gpt2-medium': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
......@@ -87,71 +93,19 @@ def get_pairs(word):
prev_char = char
return pairs
class GPT2Tokenizer(object):
class GPT2Tokenizer(PreTrainedTokenizer):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a GPT2Tokenizer from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace',
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
......@@ -165,25 +119,9 @@ class GPT2Tokenizer(object):
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
@property
def vocab_size(self):
return len(self.encoder)
def bpe(self, token):
if token in self.cache:
......@@ -226,7 +164,7 @@ class GPT2Tokenizer(object):
self.cache[token] = word
return word
def tokenize(self, text):
def _tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
......@@ -237,57 +175,27 @@ class GPT2Tokenizer(object):
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.decoder.get(index, self.unk_token)
def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
text = ''.join(tokens_ids)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
if clean_up_tokenization_spaces:
text = text.replace('<unk>', '')
text = clean_up_tokenization(text)
return text
def save_vocabulary(self, vocab_path):
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
......@@ -303,14 +211,4 @@ class GPT2Tokenizer(object):
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
return vocab_file, merge_file
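For completeness, the byte-level behaviour the slimmed-down GPT2Tokenizer keeps (a sketch; model name illustrative, printed pieces indicative only):

from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Bytes are mapped to printable unicode characters before BPE, so a leading
# space survives inside the piece (conventionally rendered as 'Ġ').
tokens = tokenizer.tokenize(u"Hello world")
print(tokens)                                   # e.g. ['Hello', 'Ġworld']
ids = tokenizer.convert_tokens_to_ids(tokens)
assert tokenizer.convert_ids_to_tokens(ids) == tokens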
......@@ -20,29 +20,32 @@ import json
import logging
import os
import re
import sys
from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
},
'merges_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
},
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'openai-gpt': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
def get_pairs(word):
"""
......@@ -71,73 +74,19 @@ def text_standardize(text):
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
class OpenAIGPTTokenizer(object):
class OpenAIGPTTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer. Peculiarities:
- lower case all inputs
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
try:
import ftfy
import spacy
......@@ -145,39 +94,19 @@ class OpenAIGPTTokenizer(object):
self.fix_text = ftfy.fix_text
except ImportError:
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True,
never_split=special_tokens if special_tokens is not None else [])
self.nlp = BasicTokenizer(do_lower_case=True)
self.fix_text = None
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
if self.fix_text is None:
# Using BERT's BasicTokenizer: we can update the tokenizer
self.nlp.never_split = special_tokens
logger.info("Special tokens {}".format(self.special_tokens))
@property
def vocab_size(self):
return len(self.encoder)
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',)
......@@ -222,7 +151,7 @@ class OpenAIGPTTokenizer(object):
self.cache[token] = word
return word
def tokenize(self, text):
def _tokenize(self, text):
""" Tokenize a string. """
split_tokens = []
if self.fix_text is None:
......@@ -237,58 +166,26 @@ class OpenAIGPTTokenizer(object):
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
return split_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def _convert_id_to_token(self, index):
"""Converts an id in a token (BPE) using the vocab."""
return self.decoder.get(index, self.unk_token)
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '')
out_string = clean_up_tokenization(out_string)
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
return out_string
def save_vocabulary(self, vocab_path):
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
......@@ -304,14 +201,4 @@ class OpenAIGPTTokenizer(object):
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
return vocab_file, merge_file
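For reference, a sketch of the on-disk layout this method produces: the encoder as a JSON dict, and the merges one pair per line in rank order after a header line (the loader above skips the first line). The filenames and toy contents below are assumptions for illustration; the real names come from VOCAB_FILES_NAMES.
import json, os, tempfile

encoder = {"hel": 0, "lo</w>": 1}                      # toy vocab
bpe_ranks = {("h", "e"): 0, ("he", "l"): 1}            # toy merge table

save_directory = tempfile.mkdtemp()
with open(os.path.join(save_directory, "vocab.json"), "w", encoding="utf-8") as f:
    f.write(json.dumps(encoder, ensure_ascii=False))
with open(os.path.join(save_directory, "merges.txt"), "w", encoding="utf-8") as f:
    f.write("# merge rules, one per line, ordered by rank\n")
    for pair, _rank in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
        f.write(" ".join(pair) + "\n")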
......@@ -31,7 +31,7 @@ import torch
import numpy as np
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
if sys.version_info[0] == 2:
import cPickle as pickle
......@@ -41,66 +41,43 @@ else:
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = {
'pretrained_vocab_file':
{
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'transfo-xl-wt103': 512,
}
VOCAB_NAME = 'vocab.bin'
PRETRAINED_CORPUS_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
}
CORPUS_NAME = 'corpus.bin'
class TransfoXLTokenizer(object):
class TransfoXLTokenizer(PreTrainedTokenizer):
"""
Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a TransfoXLTokenizer.
The TransfoXLTokenizer.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if os.path.isdir(pretrained_model_name_or_path):
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
else:
vocab_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs)
vocab_dict = torch.load(resolved_vocab_file)
for key, value in vocab_dict.items():
tokenizer.__dict__[key] = value
return tokenizer
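The resolution logic in from_pretrained boils down to three cases before anything is downloaded or opened. A self-contained sketch, with the archive map trimmed to the single entry shown in this file:
import os

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
}
VOCAB_NAME = 'vocab.bin'

def resolve_vocab_file(name_or_path):
    # Shortcut name -> S3 URL, directory -> expected vocab file inside it,
    # anything else is treated as a direct file path or URL.
    if name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
        return PRETRAINED_VOCAB_ARCHIVE_MAP[name_or_path]
    if os.path.isdir(name_or_path):
        return os.path.join(name_or_path, VOCAB_NAME)
    return name_or_path

assert resolve_vocab_file('transfo-xl-wt103').startswith("https://")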
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, pretrained_vocab_file=None,
never_split=None, unk_token="<unk>", eos_token="<eos>",
additional_special_tokens=["<formula>"], **kwargs):
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
if never_split is None:
never_split = self.all_special_tokens
if special is None:
special = []
self.counter = Counter()
self.special = special
self.min_freq = min_freq
......@@ -110,6 +87,13 @@ class TransfoXLTokenizer(object):
self.vocab_file = vocab_file
self.never_split = never_split
if pretrained_vocab_file is not None:
# Hack because, honestly, this tokenizer was not made to be used
# in a library like ours, at all.
vocab_dict = torch.load(pretrained_vocab_file)
for key, value in vocab_dict.items():
self.__dict__[key] = value
if vocab_file is not None:
self.build_vocab()
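The pretrained_vocab_file branch relies on save_vocabulary having dumped the tokenizer's __dict__ with torch.save, so loading is just restoring every entry onto the instance. A sketch of that round trip, with a plain object standing in for the tokenizer:
import os
import tempfile
import torch

class _ToyVocab(object):
    def __init__(self):
        self.idx2sym = ["<unk>", "<eos>", "hello"]
        self.sym2idx = {s: i for i, s in enumerate(self.idx2sym)}

src = _ToyVocab()
path = os.path.join(tempfile.mkdtemp(), "vocab.bin")
torch.save(src.__dict__, path)                 # what save_vocabulary does

dst = _ToyVocab()
for key, value in torch.load(path).items():    # what __init__ does above
    dst.__dict__[key] = value
assert dst.sym2idx == src.sym2idx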
......@@ -157,7 +141,7 @@ class TransfoXLTokenizer(object):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
torch.save(self.__dict__, vocab_file)
return (vocab_file,)
......@@ -224,11 +208,13 @@ class TransfoXLTokenizer(object):
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx):
def _convert_id_to_token(self, idx):
"""Converts an id in a token (BPE) using the vocab."""
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
return self.idx2sym[idx]
def get_idx(self, sym):
def _convert_token_to_id(self, sym):
""" Converts a token (str/unicode) in an id using the vocab. """
if sym in self.sym2idx:
return self.sym2idx[sym]
else:
......@@ -244,36 +230,19 @@ class TransfoXLTokenizer(object):
else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
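In short: known symbols map to their index, anything else falls back to <unk> when one exists. A toy illustration with an invented vocabulary:
sym2idx = {"<unk>": 0, "<eos>": 1, "hello": 2}

def to_id(sym):
    if sym in sym2idx:
        return sym2idx[sym]
    if "<unk>" in sym2idx:
        return sym2idx["<unk>"]
    raise ValueError("Token not in vocabulary and no <unk> token for replacement")

assert to_id("hello") == 2
assert to_id("zebra") == 0   # unknown word falls back to <unk>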
def convert_ids_to_tokens(self, indices):
"""Converts a sequence of indices in symbols using the vocab."""
return [self.get_sym(idx) for idx in indices]
def convert_tokens_to_ids(self, symbols):
"""Converts a sequence of symbols into ids using the vocab."""
return [self.get_idx(sym) for sym in symbols]
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids in a string."""
out_string = ' '.join(tokens_ids).strip()
return out_string
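Because the Transformer-XL vocabulary is word level, converting back to text is a plain whitespace join, with no '</w>' handling like the BPE tokenizer earlier in this diff. Toy symbols for illustration:
symbols = ["hello", "world", "<eos>"]            # illustrative only
assert ' '.join(symbols).strip() == "hello world <eos>"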
def convert_to_tensor(self, symbols):
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
"""Converts a sequence of indices in a string."""
if exclude is None:
out_string = ' '.join([self.get_sym(idx) for idx in indices])
else:
out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
if clean_up_tokenization_spaces:
out_string = clean_up_tokenization(out_string)
return out_string
def __len__(self):
@property
def vocab_size(self):
return len(self.idx2sym)
def tokenize(self, line, add_eos=False, add_double_eos=False):
def _tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip()
# convert to lower case
if self.lower_case:
......@@ -484,7 +453,7 @@ class TransfoXLCorpus(object):
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
corpus_file))
return None
......