Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
51d7ebf2
Unverified
Commit
51d7ebf2
authored
Jan 14, 2022
by
SaulLu
Committed by
GitHub
Jan 14, 2022
Browse files
fix BertTokenizerFast `tokenize_chinese_chars` arg (#15158)
* add new test * fix in init * more relevant test
parent
4aa16fce
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
7 deletions
+46
-7
src/transformers/models/bert/tokenization_bert_fast.py
src/transformers/models/bert/tokenization_bert_fast.py
+9
-7
tests/test_tokenization_bert.py
tests/test_tokenization_bert.py
+37
-0
No files found.
src/transformers/models/bert/tokenization_bert_fast.py
View file @
51d7ebf2
...
@@ -188,15 +188,17 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
...
@@ -188,15 +188,17 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
**
kwargs
,
**
kwargs
,
)
)
pre_tok
_state
=
json
.
loads
(
self
.
backend_tokenizer
.
normalizer
.
__getstate__
())
normalizer
_state
=
json
.
loads
(
self
.
backend_tokenizer
.
normalizer
.
__getstate__
())
if
(
if
(
pre_tok_state
.
get
(
"lowercase"
,
do_lower_case
)
!=
do_lower_case
normalizer_state
.
get
(
"lowercase"
,
do_lower_case
)
!=
do_lower_case
or
pre_tok_state
.
get
(
"strip_accents"
,
strip_accents
)
!=
strip_accents
or
normalizer_state
.
get
(
"strip_accents"
,
strip_accents
)
!=
strip_accents
or
normalizer_state
.
get
(
"handle_chinese_chars"
,
tokenize_chinese_chars
)
!=
tokenize_chinese_chars
):
):
pre_tok_class
=
getattr
(
normalizers
,
pre_tok_state
.
pop
(
"type"
))
normalizer_class
=
getattr
(
normalizers
,
normalizer_state
.
pop
(
"type"
))
pre_tok_state
[
"lowercase"
]
=
do_lower_case
normalizer_state
[
"lowercase"
]
=
do_lower_case
pre_tok_state
[
"strip_accents"
]
=
strip_accents
normalizer_state
[
"strip_accents"
]
=
strip_accents
self
.
backend_tokenizer
.
normalizer
=
pre_tok_class
(
**
pre_tok_state
)
normalizer_state
[
"handle_chinese_chars"
]
=
tokenize_chinese_chars
self
.
backend_tokenizer
.
normalizer
=
normalizer_class
(
**
normalizer_state
)
self
.
do_lower_case
=
do_lower_case
self
.
do_lower_case
=
do_lower_case
...
...
tests/test_tokenization_bert.py
View file @
51d7ebf2
...
@@ -299,3 +299,40 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -299,3 +299,40 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
[
e
[
1
]
for
e
in
expected_results
],
tokenizer_r
.
convert_ids_to_tokens
(
tokens
[
"input_ids"
])
[
e
[
1
]
for
e
in
expected_results
],
tokenizer_r
.
convert_ids_to_tokens
(
tokens
[
"input_ids"
])
)
)
self
.
assertEqual
([
e
[
0
]
for
e
in
expected_results
],
tokens
[
"offset_mapping"
])
self
.
assertEqual
([
e
[
0
]
for
e
in
expected_results
],
tokens
[
"offset_mapping"
])
def
test_change_tokenize_chinese_chars
(
self
):
list_of_commun_chinese_char
=
[
"的"
,
"人"
,
"有"
]
text_with_chinese_char
=
""
.
join
(
list_of_commun_chinese_char
)
for
tokenizer
,
pretrained_name
,
kwargs
in
self
.
tokenizers_list
:
with
self
.
subTest
(
f
"
{
tokenizer
.
__class__
.
__name__
}
(
{
pretrained_name
}
)"
):
kwargs
[
"tokenize_chinese_chars"
]
=
True
tokenizer_p
=
self
.
tokenizer_class
.
from_pretrained
(
pretrained_name
,
**
kwargs
)
tokenizer_r
=
self
.
rust_tokenizer_class
.
from_pretrained
(
pretrained_name
,
**
kwargs
)
ids_without_spe_char_p
=
tokenizer_p
.
encode
(
text_with_chinese_char
,
add_special_tokens
=
False
)
ids_without_spe_char_r
=
tokenizer_r
.
encode
(
text_with_chinese_char
,
add_special_tokens
=
False
)
tokens_without_spe_char_r
=
tokenizer_r
.
convert_ids_to_tokens
(
ids_without_spe_char_r
)
tokens_without_spe_char_p
=
tokenizer_p
.
convert_ids_to_tokens
(
ids_without_spe_char_p
)
# it is expected that each Chinese character is not preceded by "##"
self
.
assertListEqual
(
tokens_without_spe_char_p
,
list_of_commun_chinese_char
)
self
.
assertListEqual
(
tokens_without_spe_char_r
,
list_of_commun_chinese_char
)
kwargs
[
"tokenize_chinese_chars"
]
=
False
tokenizer_r
=
self
.
rust_tokenizer_class
.
from_pretrained
(
pretrained_name
,
**
kwargs
)
tokenizer_p
=
self
.
tokenizer_class
.
from_pretrained
(
pretrained_name
,
**
kwargs
)
ids_without_spe_char_r
=
tokenizer_r
.
encode
(
text_with_chinese_char
,
add_special_tokens
=
False
)
ids_without_spe_char_p
=
tokenizer_p
.
encode
(
text_with_chinese_char
,
add_special_tokens
=
False
)
tokens_without_spe_char_r
=
tokenizer_r
.
convert_ids_to_tokens
(
ids_without_spe_char_r
)
tokens_without_spe_char_p
=
tokenizer_p
.
convert_ids_to_tokens
(
ids_without_spe_char_p
)
# it is expected that only the first Chinese character is not preceded by "##".
expected_tokens
=
[
f
"##
{
token
}
"
if
idx
!=
0
else
token
for
idx
,
token
in
enumerate
(
list_of_commun_chinese_char
)
]
self
.
assertListEqual
(
tokens_without_spe_char_p
,
expected_tokens
)
self
.
assertListEqual
(
tokens_without_spe_char_r
,
expected_tokens
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment