"vscode:/vscode.git/clone" did not exist on "2b07b9e5ee14ac37fcef7bac958963d869b3b79a"
Unverified commit a52d56c8 authored by Thomas Wolf, committed by GitHub

Merge branch 'master' into cleanup-configs

parents 8ade2040 e92bcb7e
...@@ -67,6 +67,5 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

if __name__ == '__main__':
    unittest.main()
# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from transformers.tokenization_t5 import (T5Tokenizer)
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
from .tokenization_tests_commons import CommonTestCases
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')


class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):

    tokenizer_class = T5Tokenizer

    def setUp(self):
        super(T5TokenizationTest, self).setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = T5Tokenizer(SAMPLE_VOCAB)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = u"This is a test"
        output_text = u"This is a test"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids, [8, 21, 84, 55, 24, 19, 7, 0,
                  602, 347, 347, 347, 3, 12, 66,
                  46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                           u'or', u'n', SPIECE_UNDERLINE + u'in',
                                           SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
                                           SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                           SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                           u'<unk>', u'.'])


if __name__ == '__main__':
    unittest.main()
...@@ -232,6 +232,15 @@ class CommonTestCases:
            self.assertNotEqual(len(tokens_2), 0)
            self.assertIsInstance(text_2, (str, unicode))

        def test_encode_decode_with_spaces(self):
            tokenizer = self.get_tokenizer()

            new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
            tokenizer.add_tokens(new_toks)
            input = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
            encoded = tokenizer.encode(input, add_special_tokens=False)
            decoded = tokenizer.decode(encoded)
            self.assertEqual(decoded, input)
        def test_pretrained_model_lists(self):
            weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
...
...@@ -30,6 +30,7 @@ from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_albert import AlbertTokenizer
from .tokenization_t5 import T5Tokenizer

logger = logging.getLogger(__name__)
...@@ -44,6 +45,7 @@ class AutoTokenizer(object):
        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
...@@ -69,6 +71,7 @@ class AutoTokenizer(object):
        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
...@@ -119,7 +122,9 @@ class AutoTokenizer(object):
            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')
        """
        if 't5' in pretrained_model_name_or_path:
            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'distilbert' in pretrained_model_name_or_path:
            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'albert' in pretrained_model_name_or_path:
            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
...
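For context, a minimal sketch of how the new dispatch branch would be exercised, assuming network access to download the t5-small shortcut files (which are not part of this diff):

from transformers import AutoTokenizer
from transformers.tokenization_t5 import T5Tokenizer

# Any shortcut name or path containing "t5" now resolves to T5Tokenizer
# before the distilbert/albert/camembert checks are reached.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
assert isinstance(tokenizer, T5Tokenizer)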
# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model T5."""
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import re
import six
from shutil import copyfile
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
SPIECE_UNDERLINE = u'▁'
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
        't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
    }
}
####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    't5-small': 512,
    't5-base': 512,
    't5-large': 512,
    't5-3b': 512,
    't5-11b': 512,
}
class T5Tokenizer(PreTrainedTokenizer):
    """
        SentencePiece based tokenizer. Peculiarities:

            - requires `SentencePiece <https://github.com/google/sentencepiece>`_
            - `extra_ids` adds a number of extra ids to the end of the vocabulary, for use as sentinel tokens.
              These tokens are accessible as `<extra_id_{%d}>`, where `{%d}` is a number between 0 and extra_ids-1.
              Extra tokens are indexed from the end of the vocabulary up to the beginning
              (`<extra_id_0>` is the last token in the vocabulary), as in T5 preprocessing; see
              https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>",
                 pad_token="<pad>", extra_ids=100, additional_special_tokens=None, **kwargs):
        # Add extra_ids to the special token list
        if extra_ids > 0:
            if additional_special_tokens is None:
                additional_special_tokens = []
            additional_special_tokens.extend([u"<extra_id_{}>".format(i) for i in range(extra_ids)])

        super(T5Tokenizer, self).__init__(eos_token=eos_token, unk_token=unk_token,
                                          pad_token=pad_token, additional_special_tokens=additional_special_tokens,
                                          **kwargs)

        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning("You need to install SentencePiece to use T5Tokenizer: "
                           "https://github.com/google/sentencepiece "
                           "(pip install sentencepiece)")

        self.vocab_file = vocab_file
        self._extra_ids = extra_ids

        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
    @property
    def vocab_size(self):
        return self.sp_model.get_piece_size() + self._extra_ids

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning("You need to install SentencePiece to use T5Tokenizer: "
                           "https://github.com/google/sentencepiece "
                           "(pip install sentencepiece)")
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)
    def _tokenize(self, text, return_unicode=True, sample=False):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words
        """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)

        # convert back to unicode for py2
        if six.PY2 and return_unicode:
            ret_pieces = []
            for piece in pieces:
                if isinstance(piece, str):
                    piece = piece.decode('utf-8')
                ret_pieces.append(piece)
            pieces = ret_pieces

        return pieces
    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) to an id using the vocab. """
        if token.startswith(u"<extra_id_"):
            match = re.match(r'<extra_id_(\d+)>', token)
            num = int(match.group(1))
            return self.vocab_size - num - 1
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index, return_unicode=True):
        """ Converts an index (integer) to a token (string/unicode) using the vocab. """
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            token = u"<extra_id_{}>".format(self.vocab_size - 1 - index)
        if six.PY2 and return_unicode and isinstance(token, str):
            token = token.decode('utf-8')
        return token

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (strings) into a single string. """
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string
    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
            to a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
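Illustrating the sentinel indexing described in the class docstring, a small sketch; the 'spiece.model' path is a hypothetical local SentencePiece file, and the concrete ids depend on its vocabulary size:

from transformers.tokenization_t5 import T5Tokenizer

# Hypothetical local SentencePiece model file; any spiece.model would do here.
tokenizer = T5Tokenizer('spiece.model', extra_ids=100)

# Sentinels are indexed from the end of the vocabulary:
# <extra_id_0> -> vocab_size - 1, <extra_id_1> -> vocab_size - 2, ...
sentinel_id = tokenizer._convert_token_to_id(u"<extra_id_0>")
assert sentinel_id == tokenizer.vocab_size - 1
assert tokenizer._convert_id_to_token(sentinel_id) == u"<extra_id_0>"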
...@@ -637,9 +637,11 @@ class PreTrainedTokenizer(object):
                text: The sequence to be encoded.
                **kwargs: passed to the child `self.tokenize()` method
        """
        all_special_tokens = self.all_special_tokens

        def lowercase_text(t):
            # convert non-special tokens to lowercase
            escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
            pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \
                      r'(.+?)'
            return re.sub(
...@@ -680,17 +682,17 @@ class PreTrainedTokenizer(object):
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.added_tokens_encoder \
                            and sub_text not in all_special_tokens:
                        tokenized_text += split_on_token(tok, sub_text)
                    else:
                        tokenized_text += [sub_text]
                text_list = tokenized_text

            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
                in self.added_tokens_encoder and token not in all_special_tokens \
                else [token] for token in tokenized_text)))

        added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
        tokenized_text = split_on_tokens(added_tokens, text)
        return tokenized_text
...@@ -1178,12 +1180,12 @@ class PreTrainedTokenizer(object):
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
        text = ' '.join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
...
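The last hunk makes two changes: `self.all_special_tokens` is computed once and reused by the nested helpers instead of being recomputed repeatedly, and decoded pieces are now joined with single spaces rather than each added/special token being prefixed with " " and the pieces concatenated with ''.join. A standalone sketch of the joining change, using made-up sub_texts values for illustration:

# Pieces collected while decoding: plain text converted via
# convert_tokens_to_string, interleaved with added/special tokens.
sub_texts = ["[ABC]", "some plain text", "[DEF]"]

# Previous behaviour: a space was prepended to each added/special token and the
# pieces were glued together with ''.join, producing a stray leading space and
# no space after a leading added token.
old = ''.join([" [ABC]", "some plain text", " [DEF]"])
print(old)  # ' [ABC]some plain text [DEF]'

# New behaviour: a plain ' '.join over the collected pieces, which is what the
# new test_encode_decode_with_spaces round-trip exercises.
new = ' '.join(sub_texts)
print(new)  # '[ABC] some plain text [DEF]'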