Unverified commit 51d7ebf2
Authored Jan 14, 2022 by SaulLu, committed by GitHub on Jan 14, 2022
fix BertTokenizerFast `tokenize_chinese_chars` arg (#15158)
* add new test
* fix in init
* more relevant test
Parent: 4aa16fce

Showing 2 changed files with 46 additions and 7 deletions:

src/transformers/models/bert/tokenization_bert_fast.py  (+9, -7)
tests/test_tokenization_bert.py  (+37, -0)
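For context before the diffs, a minimal usage sketch of the behaviour this commit addresses: `BertTokenizerFast` accepted the `tokenize_chinese_chars` argument but never propagated it to the backend normalizer, so setting it to `False` had no effect on the fast tokenizer even though the slow `BertTokenizer` honored it. The checkpoint name below is only illustrative and is not part of this commit.

from transformers import BertTokenizer, BertTokenizerFast

checkpoint = "bert-base-multilingual-cased"  # illustrative checkpoint whose vocab contains Chinese wordpieces
text = "的人有"

slow = BertTokenizer.from_pretrained(checkpoint, tokenize_chinese_chars=False)
fast = BertTokenizerFast.from_pretrained(checkpoint, tokenize_chinese_chars=False)

# With the fix, both tokenizers respect the flag and should agree, e.g. only the
# first character lacks the "##" continuation prefix (see the new test below).
print(slow.tokenize(text))
print(fast.tokenize(text))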
src/transformers/models/bert/tokenization_bert_fast.py

@@ -188,15 +188,17 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
         if (
-            pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
-            or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
+            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
+            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
+            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
         ):
-            pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
-            pre_tok_state["lowercase"] = do_lower_case
-            pre_tok_state["strip_accents"] = strip_accents
-            self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
+            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
+            normalizer_state["lowercase"] = do_lower_case
+            normalizer_state["strip_accents"] = strip_accents
+            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
+            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
 
         self.do_lower_case = do_lower_case
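The change above mirrors what `__init__` already did for `lowercase` and `strip_accents`: the backend normalizer's serialized state (a JSON string returned by `__getstate__()`) is now also kept in sync for `handle_chinese_chars`. A quick way to check the propagation, as a sketch with an illustrative checkpoint name:

import json

from transformers import BertTokenizerFast

# Illustrative checkpoint; not part of the commit.
tok = BertTokenizerFast.from_pretrained("bert-base-uncased", tokenize_chinese_chars=False)

# Same serialization path the __init__ above uses to read the normalizer state.
state = json.loads(tok.backend_tokenizer.normalizer.__getstate__())
print(state["handle_chinese_chars"])  # expected to be False after this commit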
tests/test_tokenization_bert.py

@@ -299,3 +299,40 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                     [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
                 )
                 self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+
+    def test_change_tokenize_chinese_chars(self):
+        list_of_commun_chinese_char = ["的", "人", "有"]
+        text_with_chinese_char = "".join(list_of_commun_chinese_char)
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                kwargs["tokenize_chinese_chars"] = True
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+                ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+
+                tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+                tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+                # it is expected that each Chinese character is not preceded by "##"
+                self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
+                self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)
+
+                kwargs["tokenize_chinese_chars"] = False
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+                ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+
+                tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+                tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+                # it is expected that only the first Chinese character is not preceded by "##".
+                expected_tokens = [
+                    f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+                ]
+                self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+                self.assertListEqual(tokens_without_spe_char_r, expected_tokens)