Unverified Commit 78d706f3 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[fsmt tokenizer] support lowercase tokenizer (#8389)

* support lowercase tokenizer

* fix arg pos
parent 1e2acd0d
......@@ -133,6 +133,14 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
with open(src_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
# detect whether this is a do_lower_case situation, which can be derived by checking whether we
# have at least one upcase letter in the source vocab
do_lower_case = True
for k in src_vocab.keys():
if not k.islower():
do_lower_case = False
break
tgt_dict = Dictionary.load(tgt_dict_file)
tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
tgt_vocab_size = len(tgt_vocab)
......@@ -207,6 +215,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
tokenizer_conf = {
"langs": [src_lang, tgt_lang],
"model_max_length": 1024,
"do_lower_case": do_lower_case,
}
print(f"Generating {fsmt_tokenizer_config_file}")
......
......@@ -154,7 +154,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
File containing the vocabulary for the target language.
merges_file (:obj:`str`):
File containing the merges.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to lowercase the input when tokenizing.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
......@@ -186,6 +186,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
src_vocab_file=None,
tgt_vocab_file=None,
merges_file=None,
do_lower_case=False,
unk_token="<unk>",
bos_token="<s>",
sep_token="</s>",
......@@ -197,6 +198,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
src_vocab_file=src_vocab_file,
tgt_vocab_file=tgt_vocab_file,
merges_file=merges_file,
do_lower_case=do_lower_case,
unk_token=unk_token,
bos_token=bos_token,
sep_token=sep_token,
......@@ -207,6 +209,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
self.src_vocab_file = src_vocab_file
self.tgt_vocab_file = tgt_vocab_file
self.merges_file = merges_file
self.do_lower_case = do_lower_case
# cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict()
......@@ -351,6 +354,9 @@ class FSMTTokenizer(PreTrainedTokenizer):
# raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
lang = self.src_lang
if self.do_lower_case:
text = text.lower()
if bypass_tokenizer:
text = text.split()
else:
......
......@@ -151,6 +151,13 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
self.assertEqual(decoded_text, src_text)
@slow
def test_tokenizer_lower(self):
tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
tokens = tokenizer.tokenize("USA is United States of America")
expected = ["us", "a</w>", "is</w>", "un", "i", "ted</w>", "st", "ates</w>", "of</w>", "am", "er", "ica</w>"]
self.assertListEqual(tokens, expected)
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
def test_torch_encode_plus_sent_to_model(self):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment