Unverified Commit 9a12b969 authored by Patrick von Platen, committed by GitHub

[MPNet] Add slow to fast tokenizer converter (#9233)

* add converter

* delete unnecessary comments
parent f4432b7e
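
For readers following along, the change below enables loading a fast (Rust-backed) MPNet tokenizer from the existing slow tokenizer's vocab file. A minimal sketch of the user-facing path, assuming the `microsoft/mpnet-base` checkpoint name; `from_pretrained` converts on the fly via the converter added in this diff:

    # A minimal sketch, assuming the "microsoft/mpnet-base" checkpoint; when no
    # serialized tokenizer.json exists, from_pretrained builds the fast backend
    # by converting the slow tokenizer's vocab with the converter added below.
    from transformers import MPNetTokenizerFast

    tokenizer = MPNetTokenizerFast.from_pretrained("microsoft/mpnet-base")
    print(tokenizer("Hello world")["input_ids"])
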
...@@ -74,18 +74,6 @@ class BertConverter(Converter): ...@@ -74,18 +74,6 @@ class BertConverter(Converter):
vocab = self.original_tokenizer.vocab vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
# # Let the tokenizer know about special tokens if they are part of the vocab
# if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
tokenize_chinese_chars = False tokenize_chinese_chars = False
strip_accents = False strip_accents = False
do_lower_case = False do_lower_case = False
...@@ -125,18 +113,6 @@ class FunnelConverter(Converter): ...@@ -125,18 +113,6 @@ class FunnelConverter(Converter):
vocab = self.original_tokenizer.vocab vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token))) tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
# # Let the tokenizer know about special tokens if they are part of the vocab
# if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
tokenize_chinese_chars = False tokenize_chinese_chars = False
strip_accents = False strip_accents = False
do_lower_case = False do_lower_case = False
...@@ -171,6 +147,45 @@ class FunnelConverter(Converter): ...@@ -171,6 +147,45 @@ class FunnelConverter(Converter):
return tokenizer return tokenizer
class MPNetConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
tokenize_chinese_chars = False
strip_accents = False
do_lower_case = False
if hasattr(self.original_tokenizer, "basic_tokenizer"):
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
tokenizer.normalizer = normalizers.BertNormalizer(
clean_text=True,
handle_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
lowercase=do_lower_case,
)
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
cls = str(self.original_tokenizer.cls_token)
sep = str(self.original_tokenizer.sep_token)
cls_token_id = self.original_tokenizer.cls_token_id
sep_token_id = self.original_tokenizer.sep_token_id
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{cls}:0 $A:0 {sep}:0",
pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
special_tokens=[
(cls, cls_token_id),
(sep, sep_token_id),
],
)
tokenizer.decoder = decoders.WordPiece(prefix="##")
return tokenizer
class OpenAIGPTConverter(Converter): class OpenAIGPTConverter(Converter):
def converted(self) -> Tokenizer: def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder vocab = self.original_tokenizer.encoder
...@@ -602,6 +617,7 @@ SLOW_TO_FAST_CONVERTERS = { ...@@ -602,6 +617,7 @@ SLOW_TO_FAST_CONVERTERS = {
"LongformerTokenizer": RobertaConverter, "LongformerTokenizer": RobertaConverter,
"LxmertTokenizer": BertConverter, "LxmertTokenizer": BertConverter,
"MBartTokenizer": MBartConverter, "MBartTokenizer": MBartConverter,
"MPNetTokenizer": MPNetConverter,
"MobileBertTokenizer": BertConverter, "MobileBertTokenizer": BertConverter,
"OpenAIGPTTokenizer": OpenAIGPTConverter, "OpenAIGPTTokenizer": OpenAIGPTConverter,
"PegasusTokenizer": PegasusConverter, "PegasusTokenizer": PegasusConverter,
......
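
A sketch of how the registry entry above is used, assuming an already-instantiated slow `MPNetTokenizer` and the `microsoft/mpnet-base` checkpoint name: `convert_slow_tokenizer` looks up the converter by class name and returns the `tokenizers.Tokenizer` built by `MPNetConverter.converted()`. The pair template shows up in the output as two consecutive `</s>` tokens between segments:

    # Sketch of the dispatch path, assuming the "microsoft/mpnet-base" checkpoint.
    from transformers import MPNetTokenizer
    from transformers.convert_slow_tokenizer import convert_slow_tokenizer

    slow = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
    # Looks up SLOW_TO_FAST_CONVERTERS["MPNetTokenizer"] -> MPNetConverter
    backend = convert_slow_tokenizer(slow)

    enc = backend.encode("first segment", "second segment")
    # Expect something like:
    # ['<s>', 'first', 'segment', '</s>', '</s>', 'second', 'segment', '</s>']
    print(enc.tokens)
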
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import os import os
import unittest import unittest
from transformers import MPNetTokenizerFast
from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer
from transformers.testing_utils import require_tokenizers, slow from transformers.testing_utils import require_tokenizers, slow
...@@ -27,7 +28,9 @@ from .test_tokenization_common import TokenizerTesterMixin ...@@ -27,7 +28,9 @@ from .test_tokenization_common import TokenizerTesterMixin
class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase): class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MPNetTokenizer tokenizer_class = MPNetTokenizer
test_rust_tokenizer = False rust_tokenizer_class = MPNetTokenizerFast
test_rust_tokenizer = True
space_between_special_tokens = True
def setUp(self): def setUp(self):
super().setUp() super().setUp()
......
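
Beyond the shared mixin tests that `test_rust_tokenizer = True` enables, a quick hand check of slow/fast agreement, again assuming the `microsoft/mpnet-base` checkpoint:

    # Minimal slow-vs-fast agreement check, assuming "microsoft/mpnet-base".
    from transformers import MPNetTokenizer, MPNetTokenizerFast

    slow = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
    fast = MPNetTokenizerFast.from_pretrained("microsoft/mpnet-base")

    text = "A quick consistency check."
    assert slow.encode(text) == fast.encode(text)
    assert slow.encode(text, "and a pair") == fast.encode(text, "and a pair")
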