"sgl-router/py_src/vscode:/vscode.git/clone" did not exist on "5ee777c98ff558d1acc089e162f22fb9cde1b3e0"
Unverified commit 9f5f6464 authored by Thomas Wolf, committed by GitHub

Merge pull request #2211 from huggingface/fast-tokenizers

Fast tokenizers
parents 9024b199 e6ec24fa
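
For orientation, a minimal usage sketch of what this merge adds (the checkpoint name is illustrative; behaviour follows the classes introduced in the diff below):

# Hedged example: "bert-base-uncased" is only an illustration of a pretrained checkpoint.
from transformers import BertTokenizerFast, GPT2TokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = tokenizer.encode_plus("Hello, fast tokenizers!")
print(encoding["input_ids"])       # token ids produced by the Rust `tokenizers` backend
print(encoding["attention_mask"])  # all ones for a single, unpadded sequence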
@@ -86,6 +86,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.0.10",
# accessing files from S3 directly
"boto3",
# filesystem locks e.g. to prevent parallel downloads
@@ -103,12 +103,12 @@ from .pipelines import (
)
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_t5 import T5Tokenizer
@@ -20,7 +20,9 @@ import logging
import os
import unicodedata
from .tokenization_utils import PreTrainedTokenizer
import tokenizers as tk
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
@@ -525,3 +527,68 @@ def _is_punctuation(char):
if cat.startswith("P"):
return True
return False
class BertTokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
max_length=None,
pad_to_max_length=False,
stride=0,
truncation_strategy="longest_first",
add_special_tokens=True,
**kwargs
):
super(BertTokenizerFast, self).__init__(
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self._tokenizer = tk.Tokenizer(tk.models.WordPiece.from_files(vocab_file, unk_token=unk_token))
self._update_special_tokens()
self._tokenizer.with_pre_tokenizer(
tk.pre_tokenizers.BertPreTokenizer.new(
do_basic_tokenize=do_basic_tokenize,
do_lower_case=do_lower_case,
tokenize_chinese_chars=tokenize_chinese_chars,
never_split=never_split if never_split is not None else [],
)
)
self._tokenizer.with_decoder(tk.decoders.WordPiece.new())
if add_special_tokens:
self._tokenizer.with_post_processor(
tk.processors.BertProcessing.new(
(sep_token, self._tokenizer.token_to_id(sep_token)),
(cls_token, self._tokenizer.token_to_id(cls_token)),
)
)
if max_length is not None:
self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy)
self._tokenizer.with_padding(
max_length=max_length if pad_to_max_length else None,
direction=self.padding_side,
pad_id=self.pad_token_id,
pad_type_id=self.pad_token_type_id,
pad_token=self.pad_token,
)
self._decoder = tk.decoders.WordPiece.new()
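
A hedged sketch of how the constructor options above interact (the vocabulary path and sentences are placeholders; it assumes the tokenizers 0.0.10 API used in this diff):

# "bert-vocab.txt" is a placeholder for a real WordPiece vocabulary file.
fast_bert = BertTokenizerFast(
    "bert-vocab.txt",
    do_lower_case=True,
    max_length=16,           # forwarded to with_truncation(...)
    pad_to_max_length=True,  # pads every encoding up to max_length via with_padding(...)
)
enc = fast_bert.encode_plus("a first sentence", "a second sentence")
# enc["input_ids"] is truncated/padded to 16 ids; enc["token_type_ids"] marks the two segments.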
@@ -21,8 +21,9 @@ import os
from functools import lru_cache
import regex as re
import tokenizers as tk
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
@@ -246,3 +247,42 @@ class GPT2Tokenizer(PreTrainedTokenizer):
index += 1
return vocab_file, merge_file
class GPT2TokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
self,
vocab_file,
merges_file,
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
pad_to_max_length=False,
add_prefix_space=False,
max_length=None,
stride=0,
truncation_strategy="longest_first",
**kwargs
):
super(GPT2TokenizerFast, self).__init__(
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
)
self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file))
self._update_special_tokens()
self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space))
self._tokenizer.with_decoder(tk.decoders.ByteLevel.new())
if max_length:
self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy)
self._tokenizer.with_padding(
max_length=max_length if pad_to_max_length else None,
direction=self.padding_side,
pad_id=self.pad_token_id if self.pad_token_id is not None else 0,
pad_type_id=self.pad_token_type_id,
pad_token=self.pad_token if self.pad_token is not None else "",
)
self._decoder = tk.decoders.ByteLevel.new()
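
Similarly, a minimal sketch for the byte-level BPE class above (the vocab and merges paths are placeholders):

# Placeholders: GPT-2's BPE needs both a vocabulary file and a merges file.
fast_gpt2 = GPT2TokenizerFast("gpt2-vocab.json", "gpt2-merges.txt", add_prefix_space=True)
ids = fast_gpt2.encode("lower newer")   # byte-level BPE ids
text = fast_gpt2.decode(ids)            # should roughly round-trip the input (a leading space
                                        # may appear because of add_prefix_space)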
@@ -1414,3 +1414,199 @@ class PreTrainedTokenizer(object):
.replace(" 're", "'re")
)
return out_string
class PreTrainedTokenizerFast(PreTrainedTokenizer):
_tokenizer = None
_decoder = None
def __init__(self, **kwargs):
super(PreTrainedTokenizerFast, self).__init__(**kwargs)
@property
def tokenizer(self):
if self._tokenizer is None:
raise NotImplementedError
return self._tokenizer
@property
def decoder(self):
if self._decoder is None:
raise NotImplementedError
return self._decoder
@property
def vocab_size(self):
return self.tokenizer.get_vocab_size(with_added_tokens=False)
def __len__(self):
return self.tokenizer.get_vocab_size(with_added_tokens=True)
@PreTrainedTokenizer.bos_token.setter
def bos_token(self, value):
self._bos_token = value
self._update_special_tokens()
@PreTrainedTokenizer.eos_token.setter
def eos_token(self, value):
self._eos_token = value
self._update_special_tokens()
@PreTrainedTokenizer.unk_token.setter
def unk_token(self, value):
self._unk_token = value
self._update_special_tokens()
@PreTrainedTokenizer.sep_token.setter
def sep_token(self, value):
self._sep_token = value
self._update_special_tokens()
@PreTrainedTokenizer.pad_token.setter
def pad_token(self, value):
self._pad_token = value
self._update_special_tokens()
@PreTrainedTokenizer.cls_token.setter
def cls_token(self, value):
self._cls_token = value
self._update_special_tokens()
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
self._mask_token = value
self._update_special_tokens()
@PreTrainedTokenizer.additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
self._update_special_tokens()
def _update_special_tokens(self):
if self._tokenizer is not None:
self._tokenizer.add_special_tokens(self.all_special_tokens)
@staticmethod
def _convert_encoding(
encoding,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
):
encoding_dict = {
"input_ids": encoding.ids,
}
if return_token_type_ids:
encoding_dict["token_type_ids"] = encoding.type_ids
if return_attention_mask:
encoding_dict["attention_mask"] = encoding.attention_mask
if return_overflowing_tokens:
overflowing = encoding.overflowing
encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else []
if return_special_tokens_mask:
encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask
# Prepare inputs as tensors if asked
if return_tensors == "tf" and is_tf_available():
encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])
elif return_tensors == "pt" and is_torch_available():
encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])
if "attention_mask" in encoding_dict:
encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]])
elif return_tensors is not None:
logger.warning(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
return_tensors
)
)
return encoding_dict
def encode_plus(
self,
text,
text_pair=None,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
**kwargs
):
encoding = self.tokenizer.encode(text, text_pair)
return self._convert_encoding(
encoding,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
)
def tokenize(self, text):
return self.tokenizer.encode(text).tokens
def _convert_token_to_id_with_added_voc(self, token):
id = self.tokenizer.token_to_id(token)
if id is None:
return self.unk_token_id
return id
def _convert_id_to_token(self, index):
return self.tokenizer.id_to_token(int(index))
def convert_tokens_to_string(self, tokens):
return self.decoder.decode(tokens)
def add_tokens(self, new_tokens):
self.tokenizer.add_tokens(new_tokens)
def add_special_tokens(self, special_tokens_dict):
added = super().add_special_tokens(special_tokens_dict)
self._update_special_tokens()
return added
def encode_batch(
self,
texts,
return_tensors=None,
return_token_type_ids=True,
return_attention_mask=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
):
return [
self._convert_encoding(
encoding,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
)
for encoding in self.tokenizer.encode_batch(texts)
]
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = self.tokenizer.decode(token_ids, skip_special_tokens)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
    def decode_batch(self, ids_batch, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        return [
            self.clean_up_tokenization(text) if clean_up_tokenization_spaces else text
for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)
]
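
The base class above is shared by both fast tokenizers; a hedged sketch of the batch helpers it adds (`fast_bert` stands for any concrete PreTrainedTokenizerFast instance, e.g. the one constructed earlier):

batch = fast_bert.encode_batch(["first sentence", "second sentence"])
for enc in batch:
    print(enc["input_ids"], enc["attention_mask"])  # one dict per input text

# decode() mirrors encode_plus(): special tokens can be stripped and spacing cleaned up.
print(fast_bert.decode(batch[0]["input_ids"], skip_special_tokens=True))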
@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
VOCAB_FILES_NAMES,
BasicTokenizer,
BertTokenizer,
BertTokenizerFast,
WordpieceTokenizer,
_is_control,
_is_punctuation,
@@ -34,6 +35,7 @@ from .utils import slow
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BertTokenizer
test_rust_tokenizer = True
def setUp(self):
super(BertTokenizationTest, self).setUp()
@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
sequence = u"UNwant\u00E9d,running"
tokens = tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()
ids = tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
def test_chinese(self):
tokenizer = BasicTokenizer()
@@ -23,6 +23,7 @@ import tempfile
class TokenizerTesterMixin:
tokenizer_class = None
test_rust_tokenizer = False
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
@@ -33,6 +34,9 @@ class TokenizerTesterMixin:
def get_tokenizer(self, **kwargs):
raise NotImplementedError
def get_rust_tokenizer(self, **kwargs):
raise NotImplementedError
def get_input_output_texts(self):
raise NotImplementedError
@@ -18,7 +18,7 @@ import json
import os
import unittest
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = GPT2Tokenizer
test_rust_tokenizer = True
def setUp(self):
super(GPT2TokenizationTest, self).setUp()
@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
kwargs.update(self.special_tokens_map)
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "lower newer"
output_text = "lower newer"
@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
sequence = u"lower newer"
# Testing tokenization
tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
# Testing conversion to ids without special tokens
ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# Testing conversion to ids with special tokens
rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
ids = tokenizer.encode(sequence, add_prefix_space=True)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# Testing the unknown token
input_tokens = tokens + [rust_tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
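
Assuming the repository's usual pytest setup (the selection expression is the only assumption here), the new slow-vs-fast parity tests can be run on their own:

# Runs only the comparison tests added in this PR.
import pytest
pytest.main(["-k", "test_rust_and_python_full_tokenizers"])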