Commit 36bca545 authored by thomwolf

tokenization abstract class - tests for examples

parent a4f98054
......@@ -20,7 +20,7 @@ import json
import shutil
import pytest
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
from .tokenization_tests_commons import create_and_check_tokenizer_commons
......@@ -58,14 +58,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
......@@ -20,7 +20,7 @@ from io import open
import shutil
import pytest
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_tests_commons import create_and_check_tokenizer_commons
......@@ -59,13 +59,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
# coding=utf-8
# Copyright 2018 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
class TokenizerUtilsTest(unittest.TestCase):
def check_tokenizer_from_pretrained(self, tokenizer_class):
s3_models = list(tokenizer_class.max_model_input_sizes.keys())
for model_name in s3_models[:1]:
tokenizer = tokenizer_class.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
if __name__ == "__main__":
unittest.main()
......@@ -20,9 +20,9 @@ import json
import shutil
import pytest
from pytorch_transformers.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from pytorch_transformers.tokenization_xlm import XLMTokenizer
from.tokenization_tests_commons import create_and_check_tokenizer_commons
from .tokenization_tests_commons import create_and_check_tokenizer_commons
class XLMTokenizationTest(unittest.TestCase):
......@@ -57,14 +57,6 @@ class XLMTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = XLMTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
if __name__ == '__main__':
unittest.main()
......@@ -19,9 +19,7 @@ import unittest
import shutil
import pytest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer,
PRETRAINED_VOCAB_ARCHIVE_MAP,
SPIECE_UNDERLINE)
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from .tokenization_tests_commons import create_and_check_tokenizer_commons
......@@ -60,14 +58,6 @@ class XLNetTokenizationTest(unittest.TestCase):
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(tokenizer)
def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
......
......@@ -23,11 +23,15 @@ import unicodedata
from io import open
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
......@@ -41,8 +45,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
}}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
......@@ -57,7 +62,6 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-large-cased-whole-word-masking-finetuned-squad': 512,
'bert-base-cased-finetuned-mrpc': 512,
}
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
......@@ -83,8 +87,11 @@ def whitespace_tokenize(text):
return tokens
class BertTokenizer(object):
class BertTokenizer(PreTrainedTokenizer):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
......@@ -203,7 +210,7 @@ class BertTokenizer(object):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
......@@ -215,13 +222,10 @@ class BertTokenizer(object):
return (vocab_file,)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" Instantiate a BertTokenizer from pre-trained vocabulary files.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
......@@ -232,40 +236,8 @@ class BertTokenizer(object):
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True
else:
vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
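With the loading logic moved into the base class, a BertTokenizer can be instantiated either from one of the S3 shortcut names above or from a local directory holding a vocab.txt. A minimal sketch (the local path is hypothetical):

from pytorch_transformers.tokenization_bert import BertTokenizer

# Shortcut name: the vocabulary file is downloaded and cached, and max_len is
# capped at the 512 positional embeddings of the pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer.tokenize(u"Hello, how are you?")
ids = tokenizer.convert_tokens_to_ids(tokens)

# Local directory containing a 'vocab.txt' (hypothetical path):
# tokenizer = BertTokenizer.from_pretrained('/path/to/saved_tokenizer/')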
class BasicTokenizer(object):
......
......@@ -23,8 +23,6 @@ import os
import regex as re
from io import open
from .model_utils import clean_up_tokenization
try:
from functools import lru_cache
except ImportError:
......@@ -33,24 +31,38 @@ except ImportError:
def lru_cache():
return lambda func: func
from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
'special_tokens_file': 'special_tokens.txt'
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
},
'merges_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
},
'special_tokens_file':
{
'gpt2': None,
'gpt2-medium': None,
}
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024,
'gpt2-medium': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
......@@ -87,70 +99,16 @@ def get_pairs(word):
prev_char = char
return pairs
class GPT2Tokenizer(object):
class GPT2Tokenizer(PreTrainedTokenizer):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a GPT2Tokenizer from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
......@@ -165,9 +123,16 @@ class GPT2Tokenizer(object):
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
all_special_tokens = []
if special_tokens_file is not None:
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
all_special_tokens.extend(special_tokens_to_add)
if special_tokens is not None and special_tokens:
all_special_tokens.extend(special_tokens)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
self.set_special_tokens(all_special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
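Special tokens gathered from special_tokens_file and/or the special_tokens argument are indexed after the byte-level BPE vocabulary, which is why __len__ sums the two counts. A hedged sketch with hypothetical local file paths:

from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer

# Hypothetical local files, named after VOCAB_FILES_NAMES.
tok = GPT2Tokenizer('/tmp/gpt2/vocab.json', '/tmp/gpt2/merges.txt',
                    special_tokens=['<start>', '<end>'])
# '<start>' and '<end>' receive ids len(tok.encoder) and len(tok.encoder) + 1.
assert len(tok) == len(tok.encoder) + 2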
......@@ -285,9 +250,9 @@ class GPT2Tokenizer(object):
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
......
......@@ -26,23 +26,35 @@ from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
'special_tokens_file': 'special_tokens.txt'
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
},
'merges_file':
{
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
},
'special_tokens_file':
{
'openai-gpt': None,
}
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'openai-gpt': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
def get_pairs(word):
"""
......@@ -71,7 +83,7 @@ def text_standardize(text):
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
class OpenAIGPTTokenizer(object):
class OpenAIGPTTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer. Peculiarities:
- lower case all inputs
......@@ -79,65 +91,11 @@ class OpenAIGPTTokenizer(object):
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
try:
import ftfy
import spacy
......@@ -156,9 +114,17 @@ class OpenAIGPTTokenizer(object):
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
all_special_tokens = []
if special_tokens_file is not None:
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
all_special_tokens.extend(special_tokens_to_add)
if special_tokens is not None and special_tokens:
all_special_tokens.extend(special_tokens)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
self.set_special_tokens(all_special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
......@@ -286,9 +252,9 @@ class OpenAIGPTTokenizer(object):
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
......
......@@ -31,7 +31,7 @@ import torch
import numpy as np
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
if sys.version_info[0] == 2:
import cPickle as pickle
......@@ -41,66 +41,35 @@ else:
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'}
PRETRAINED_VOCAB_FILES_MAP = {
'pretrained_vocab_file':
{
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'transfo-xl-wt103': 512,
}
VOCAB_NAME = 'vocab.bin'
PRETRAINED_CORPUS_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
}
CORPUS_NAME = 'corpus.bin'
class TransfoXLTokenizer(object):
class TransfoXLTokenizer(PreTrainedTokenizer):
"""
Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a TransfoXLTokenizer.
The TransfoXLTokenizer.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
if os.path.isdir(pretrained_model_name_or_path):
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
else:
vocab_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs)
vocab_dict = torch.load(resolved_vocab_file)
for key, value in vocab_dict.items():
tokenizer.__dict__[key] = value
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
delimiter=None, vocab_file=None, pretrained_vocab_file=None,
never_split=("<unk>", "<eos>", "<formula>")):
self.counter = Counter()
self.special = special
self.min_freq = min_freq
......@@ -110,6 +79,13 @@ class TransfoXLTokenizer(object):
self.vocab_file = vocab_file
self.never_split = never_split
if pretrained_vocab_file is not None:
# Hack because, honestly, this tokenizer was not made to be used
# in a library like ours, at all.
vocab_dict = torch.load(pretrained_vocab_file)
for key, value in vocab_dict.items():
self.__dict__[key] = value
if vocab_file is not None:
self.build_vocab()
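The torch-pickled wt103 vocabulary is now routed through the new pretrained_vocab_file argument instead of being patched onto the instance by from_pretrained. A hedged sketch of both entry points (the local path is hypothetical):

from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer

# Shortcut name: the base class downloads 'vocab.bin' and hands it to
# __init__ as pretrained_vocab_file, which copies the pickled attributes.
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

# Local torch-saved vocabulary, bypassing the download (hypothetical path):
# tokenizer = TransfoXLTokenizer(pretrained_vocab_file='/tmp/transfo-xl/vocab.bin')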
......@@ -157,7 +133,7 @@ class TransfoXLTokenizer(object):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
torch.save(self.__dict__, vocab_file)
return (vocab_file,)
......@@ -484,7 +460,7 @@ class TransfoXLCorpus(object):
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
corpus_file))
return None
......
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
from .file_utils import cached_path
logger = logging.getLogger(__name__)
class PreTrainedTokenizer(object):
""" An abstract class to handle dowloading and loading pretrained tokenizers.
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
max_model_input_sizes = {}
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
return cls._from_pretrained(*inputs, **kwargs)
@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
Download and cache the vocabulary files if needed.
"""
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
if pretrained_model_name_or_path in s3_models:
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
else:
for file_id, file_name in cls.vocab_files_names.items():
if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
else:
full_file_name = pretrained_model_name_or_path
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We don't load it.".format(full_file_name))
full_file_name = None
vocab_files[file_id] = full_file_name
# redirect to the cache, if necessary
try:
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path, str(vocab_files.keys())))
return None
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
else:
logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id]))
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# won't index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(*inputs, **resolved_vocab_files, **kwargs)
return tokenizer
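A concrete tokenizer only needs to fill in the three class attributes and accept the resolved file paths in its constructor; _from_pretrained then handles shortcut-name resolution, caching and the max_len cap. A minimal hedged sketch with a made-up subclass (names and URL are illustrative, not part of the library):

class ToyTokenizer(PreTrainedTokenizer):
    # File name expected inside a local directory.
    vocab_files_names = {'vocab_file': 'vocab.txt'}
    # Shortcut-name -> URL map (illustrative URL).
    pretrained_vocab_files_map = {
        'vocab_file': {'toy-model': 'https://example.com/toy-vocab.txt'}}
    # Sequence-length cap applied when a shortcut name is used.
    max_model_input_sizes = {'toy-model': 512}

    def __init__(self, vocab_file, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        with open(vocab_file, encoding='utf-8') as f:
            self.vocab = [line.rstrip('\n') for line in f]

# Works with a shortcut name or a directory containing 'vocab.txt':
# tok = ToyTokenizer.from_pretrained('toy-model')
# tok = ToyTokenizer.from_pretrained('/path/to/vocab_dir/')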
def clean_up_tokenization(out_string):
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string
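A quick illustration of the replacement chain above:

assert clean_up_tokenization("do n't stop , please !") == "don't stop, please!"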
......@@ -26,30 +26,42 @@ from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
'special_tokens_file': 'special_tokens.txt'
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
},
'merges_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
},
'special_tokens_file':
{
'xlm-mlm-en-2048': None,
}
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-en-2048': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
INDEX= {
"bos_index": 0,
"eos_index": 1,
"pad_index": 2,
"unk_index": 3,
"mask_index": 5
INDEX = {
"bos_index": 0,
"eos_index": 1,
"pad_index": 2,
"unk_index": 3,
"mask_index": 5
}
def get_pairs(word):
......@@ -79,7 +91,7 @@ def text_standardize(text):
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
class XLMTokenizer(object):
class XLMTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
- lower case all inputs
......@@ -87,65 +99,11 @@ class XLMTokenizer(object):
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None):
try:
import ftfy
import spacy
......@@ -164,9 +122,17 @@ class XLMTokenizer(object):
merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
all_special_tokens = []
if special_tokens_file is not None:
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
all_special_tokens.extend(special_tokens_to_add)
if special_tokens is not None and special_tokens:
all_special_tokens.extend(special_tokens)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
self.set_special_tokens(all_special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
......@@ -294,9 +260,9 @@ class XLMTokenizer(object):
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file'])
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
......
......@@ -27,15 +27,24 @@ import unicodedata
import six
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-large-cased': 512,
}
VOCAB_NAME = 'spiece.model'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
SPIECE_UNDERLINE = u'▁'
......@@ -46,7 +55,7 @@ SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4
class XLNetTokenizer(object):
class XLNetTokenizer(PreTrainedTokenizer):
"""
SentencePiece based tokenizer. Peculiarities:
- requires SentencePiece: https://github.com/google/sentencepiece
......@@ -63,64 +72,11 @@ class XLNetTokenizer(object):
"<eod>" : 7,
"<eop>" : 8,
}
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
"you may want to check this behavior.")
kwargs['do_lower_case'] = False
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {}"
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, special_tokens=None, max_len=None,
def __init__(self, vocab_file, max_len=None,
do_lower_case=False, remove_space=True, keep_accents=False):
try:
import sentencepiece as spm
......@@ -136,9 +92,6 @@ class XLNetTokenizer(object):
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
@property
def UNK_TOKEN(self):
......@@ -181,7 +134,7 @@ class XLNetTokenizer(object):
return self.special_symbols["<mask>"]
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
return len(self.sp_model)
def __getstate__(self):
state = self.__dict__.copy()
......@@ -198,19 +151,6 @@ class XLNetTokenizer(object):
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.sp_model) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens: %s", str(self.special_tokens))
def preprocess_text(self, inputs):
if self.remove_space:
outputs = ' '.join(inputs.strip().split())
......@@ -272,15 +212,9 @@ class XLNetTokenizer(object):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.sp_model.PieceToId(tokens)
return self.sp_model.PieceToId(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.sp_model.PieceToId(token))
ids.append(self.sp_model.PieceToId(token))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
......@@ -289,15 +223,11 @@ class XLNetTokenizer(object):
)
return ids
def convert_ids_to_tokens(self, ids, return_unicode=True, skip_special_tokens=False):
def convert_ids_to_tokens(self, ids, return_unicode=True):
"""Converts a sequence of ids in tokens."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.sp_model.IdToPiece(i))
tokens.append(self.sp_model.IdToPiece(i))
if six.PY2 and return_unicode:
ret_pieces = []
......@@ -311,9 +241,9 @@ class XLNetTokenizer(object):
def encode(self, text, sample=False):
return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
def decode(self, ids, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
tokens = self.convert_ids_to_tokens(ids)
out_string = ''.join(tokens)
if clean_up_tokenization_spaces:
out_string = out_string.strip().replace('<unk>', '')
......@@ -328,18 +258,7 @@ class XLNetTokenizer(object):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
copyfile(self.vocab_file, out_vocab_file)
index = len(self.sp_model)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return out_vocab_file, special_tokens_file
return (out_vocab_file,)
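With the per-tokenizer special-token bookkeeping removed, encoding and decoding for XLNet go straight through the SentencePiece model. A hedged sketch using the pretrained spiece.model:

from pytorch_transformers.tokenization_xlnet import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
ids = tokenizer.encode(u"Hello, how are you?")
text = tokenizer.decode(ids, clean_up_tokenization_spaces=True)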