chenpangpang / transformers · Commits · 9a12b969

Unverified commit 9a12b969, authored Dec 21, 2020 by Patrick von Platen and committed by GitHub on Dec 21, 2020.
[MPNet] Add slow to fast tokenizer converter (#9233)

* add converter
* delete unnecessary comments
Parent: f4432b7e
Showing 2 changed files, with 44 additions and 25 deletions:

  src/transformers/convert_slow_tokenizer.py   +40  -24
  tests/test_tokenization_mpnet.py              +4   -1
src/transformers/convert_slow_tokenizer.py
@@ -74,18 +74,6 @@ class BertConverter(Converter):
         vocab = self.original_tokenizer.vocab
         tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
 
-        # # Let the tokenizer know about special tokens if they are part of the vocab
-        # if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
-
         tokenize_chinese_chars = False
         strip_accents = False
         do_lower_case = False
@@ -125,18 +113,6 @@ class FunnelConverter(Converter):
         vocab = self.original_tokenizer.vocab
         tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
 
-        # # Let the tokenizer know about special tokens if they are part of the vocab
-        # if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
-
         tokenize_chinese_chars = False
         strip_accents = False
         do_lower_case = False
@@ -171,6 +147,45 @@ class FunnelConverter(Converter):
         return tokenizer
 
 
+class MPNetConverter(Converter):
+    def converted(self) -> Tokenizer:
+        vocab = self.original_tokenizer.vocab
+        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+        tokenize_chinese_chars = False
+        strip_accents = False
+        do_lower_case = False
+        if hasattr(self.original_tokenizer, "basic_tokenizer"):
+            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+        tokenizer.normalizer = normalizers.BertNormalizer(
+            clean_text=True,
+            handle_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            lowercase=do_lower_case,
+        )
+        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+        cls = str(self.original_tokenizer.cls_token)
+        sep = str(self.original_tokenizer.sep_token)
+        cls_token_id = self.original_tokenizer.cls_token_id
+        sep_token_id = self.original_tokenizer.sep_token_id
+
+        tokenizer.post_processor = processors.TemplateProcessing(
+            single=f"{cls}:0 $A:0 {sep}:0",
+            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",  # MPNet uses two [SEP] tokens
+            special_tokens=[
+                (cls, cls_token_id),
+                (sep, sep_token_id),
+            ],
+        )
+        tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+        return tokenizer
+
+
 class OpenAIGPTConverter(Converter):
     def converted(self) -> Tokenizer:
         vocab = self.original_tokenizer.encoder
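The heart of the new converter is the `TemplateProcessing` post-processor, which reproduces MPNet's sentence-pair layout with two separator tokens between the segments. Below is a minimal sketch (not part of the commit) that wires the same template onto a toy WordPiece vocab; the token strings `<s>`/`</s>`, their ids, and the vocab entries are illustrative assumptions, not values read from a real MPNet checkpoint.

```python
# Minimal sketch of the template installed by MPNetConverter, on a toy vocab.
from tokenizers import Tokenizer, processors
from tokenizers.models import WordPiece

vocab = {"<s>": 0, "</s>": 1, "[UNK]": 2, "hello": 3, "world": 4}  # toy vocab (assumption)
tok = Tokenizer(WordPiece(vocab, unk_token="[UNK]"))
tok.post_processor = processors.TemplateProcessing(
    single="<s>:0 $A:0 </s>:0",
    pair="<s>:0 $A:0 </s>:0 </s>:0 $B:1 </s>:1",  # two </s> between the pair, as in the converter
    special_tokens=[("<s>", 0), ("</s>", 1)],
)

enc = tok.encode("hello", "world")
print(enc.tokens)    # ['<s>', 'hello', '</s>', '</s>', 'world', '</s>']
print(enc.type_ids)  # [0, 0, 0, 0, 1, 1]
```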
@@ -602,6 +617,7 @@ SLOW_TO_FAST_CONVERTERS = {
     "LongformerTokenizer": RobertaConverter,
     "LxmertTokenizer": BertConverter,
     "MBartTokenizer": MBartConverter,
+    "MPNetTokenizer": MPNetConverter,
     "MobileBertTokenizer": BertConverter,
     "OpenAIGPTTokenizer": OpenAIGPTConverter,
     "PegasusTokenizer": PegasusConverter,
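With the `"MPNetTokenizer": MPNetConverter` entry registered, `convert_slow_tokenizer` can build a Rust-backed tokenizer from a slow MPNet tokenizer by looking up its class name. A hedged usage sketch, assuming the public `microsoft/mpnet-base` checkpoint is available locally or over the network:

```python
# Sketch only: build the fast backend from the slow tokenizer via the registry.
from transformers import MPNetTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")  # checkpoint name is an assumption
fast_backend = convert_slow_tokenizer(slow)  # looks up SLOW_TO_FAST_CONVERTERS["MPNetTokenizer"]
print(type(fast_backend))  # <class 'tokenizers.Tokenizer'>
print(fast_backend.encode("Hello world", "How are you?").tokens)
```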
tests/test_tokenization_mpnet.py
@@ -17,6 +17,7 @@
 import os
 import unittest
 
+from transformers import MPNetTokenizerFast
 from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer
 from transformers.testing_utils import require_tokenizers, slow
 
@@ -27,7 +28,9 @@ from .test_tokenization_common import TokenizerTesterMixin
 class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = MPNetTokenizer
-    test_rust_tokenizer = False
+    rust_tokenizer_class = MPNetTokenizerFast
+    test_rust_tokenizer = True
+    space_between_special_tokens = True
 
     def setUp(self):
         super().setUp()
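With `rust_tokenizer_class` set and `test_rust_tokenizer = True`, the shared `TokenizerTesterMixin` suite now runs its slow-vs-fast comparisons for MPNet as well. A standalone parity check in the same spirit (not taken from the test file; the `microsoft/mpnet-base` checkpoint name is an assumption):

```python
# Not from the test file: a quick slow/fast parity check for MPNet tokenization.
from transformers import MPNetTokenizer, MPNetTokenizerFast

slow = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")   # checkpoint is an assumption
fast = MPNetTokenizerFast.from_pretrained("microsoft/mpnet-base")

text_a, text_b = "Hello, my dog is cute.", "And my cat is too."
slow_ids = slow(text_a, text_b)["input_ids"]
fast_ids = fast(text_a, text_b)["input_ids"]
assert slow_ids == fast_ids  # same ids, including the doubled </s> between the pair
```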