Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
78d706f3
Unverified
Commit
78d706f3
authored
Nov 09, 2020
by
Stas Bekman
Committed by
GitHub
Nov 09, 2020
Browse files
[fsmt tokenizer] support lowercase tokenizer (#8389)
* support lowercase tokenizer * fix arg pos
parent
1e2acd0d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
1 deletion
+23
-1
src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
...rs/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+9
-0
src/transformers/tokenization_fsmt.py
src/transformers/tokenization_fsmt.py
+7
-1
tests/test_tokenization_fsmt.py
tests/test_tokenization_fsmt.py
+7
-0
No files found.
src/transformers/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
View file @
78d706f3
...
...
@@ -133,6 +133,14 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
with
open
(
src_vocab_file
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
json
.
dumps
(
src_vocab
,
ensure_ascii
=
False
,
indent
=
json_indent
))
# detect whether this is a do_lower_case situation, which can be derived by checking whether we
# have at least one upcase letter in the source vocab
do_lower_case
=
True
for
k
in
src_vocab
.
keys
():
if
not
k
.
islower
():
do_lower_case
=
False
break
tgt_dict
=
Dictionary
.
load
(
tgt_dict_file
)
tgt_vocab
=
rewrite_dict_keys
(
tgt_dict
.
indices
)
tgt_vocab_size
=
len
(
tgt_vocab
)
...
...
@@ -207,6 +215,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
tokenizer_conf
=
{
"langs"
:
[
src_lang
,
tgt_lang
],
"model_max_length"
:
1024
,
"do_lower_case"
:
do_lower_case
,
}
print
(
f
"Generating
{
fsmt_tokenizer_config_file
}
"
)
...
...
src/transformers/tokenization_fsmt.py
View file @
78d706f3
...
...
@@ -154,7 +154,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
File containing the vocabulary for the target language.
merges_file (:obj:`str`):
File containing the merges.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to lowercase the input when tokenizing.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
...
...
@@ -186,6 +186,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
src_vocab_file
=
None
,
tgt_vocab_file
=
None
,
merges_file
=
None
,
do_lower_case
=
False
,
unk_token
=
"<unk>"
,
bos_token
=
"<s>"
,
sep_token
=
"</s>"
,
...
...
@@ -197,6 +198,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
src_vocab_file
=
src_vocab_file
,
tgt_vocab_file
=
tgt_vocab_file
,
merges_file
=
merges_file
,
do_lower_case
=
do_lower_case
,
unk_token
=
unk_token
,
bos_token
=
bos_token
,
sep_token
=
sep_token
,
...
...
@@ -207,6 +209,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
self
.
src_vocab_file
=
src_vocab_file
self
.
tgt_vocab_file
=
tgt_vocab_file
self
.
merges_file
=
merges_file
self
.
do_lower_case
=
do_lower_case
# cache of sm.MosesPunctNormalizer instance
self
.
cache_moses_punct_normalizer
=
dict
()
...
...
@@ -351,6 +354,9 @@ class FSMTTokenizer(PreTrainedTokenizer):
# raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
lang
=
self
.
src_lang
if
self
.
do_lower_case
:
text
=
text
.
lower
()
if
bypass_tokenizer
:
text
=
text
.
split
()
else
:
...
...
tests/test_tokenization_fsmt.py
View file @
78d706f3
...
...
@@ -151,6 +151,13 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
decoded_text
=
tokenizer_dec
.
decode
(
encoded_ids
,
skip_special_tokens
=
True
)
self
.
assertEqual
(
decoded_text
,
src_text
)
@
slow
def
test_tokenizer_lower
(
self
):
tokenizer
=
FSMTTokenizer
.
from_pretrained
(
"facebook/wmt19-ru-en"
,
do_lower_case
=
True
)
tokens
=
tokenizer
.
tokenize
(
"USA is United States of America"
)
expected
=
[
"us"
,
"a</w>"
,
"is</w>"
,
"un"
,
"i"
,
"ted</w>"
,
"st"
,
"ates</w>"
,
"of</w>"
,
"am"
,
"er"
,
"ica</w>"
]
self
.
assertListEqual
(
tokens
,
expected
)
@
unittest
.
skip
(
"FSMTConfig.__init__ requires non-optional args"
)
def
test_torch_encode_plus_sent_to_model
(
self
):
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment