Unverified commit 78d706f3, authored by Stas Bekman, committed by GitHub

[fsmt tokenizer] support lowercase tokenizer (#8389)

* support lowercase tokenizer

* fix arg pos
parent 1e2acd0d
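For orientation before the diff: the change lets an FSMT tokenizer lowercase its input. A minimal usage sketch, mirroring the test added at the bottom of this commit (`facebook/wmt19-ru-en` is the checkpoint that test exercises):

```python
from transformers import FSMTTokenizer

# Pass do_lower_case=True so the tokenizer lowercases text before
# Moses tokenization and BPE (see the _tokenize change below).
tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
print(tokenizer.tokenize("USA is United States of America"))
# per the new test, this starts with ["us", "a</w>", "is</w>", ...]
```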
@@ -133,6 +133,14 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     with open(src_vocab_file, "w", encoding="utf-8") as f:
         f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
 
+    # detect whether this is a do_lower_case situation, which can be derived by checking whether we
+    # have at least one uppercase letter in the source vocab
+    do_lower_case = True
+    for k in src_vocab.keys():
+        if not k.islower():
+            do_lower_case = False
+            break
+
     tgt_dict = Dictionary.load(tgt_dict_file)
     tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
     tgt_vocab_size = len(tgt_vocab)
@@ -207,6 +215,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
     tokenizer_conf = {
         "langs": [src_lang, tgt_lang],
         "model_max_length": 1024,
+        "do_lower_case": do_lower_case,
     }
 
     print(f"Generating {fsmt_tokenizer_config_file}")
......
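The vocab scan above amounts to "lowercase only if no source-vocab key fails `str.islower()`". A standalone sketch of the same heuristic, for illustration only (the helper name `detect_do_lower_case` is hypothetical, not part of the converter):

```python
def detect_do_lower_case(src_vocab):
    # Same logic as the converter's loop: any key that is not fully
    # lowercase disables lowercasing. Note that str.islower() is also
    # False for keys with no cased characters (digits, punctuation),
    # so such keys conservatively disable it too.
    return all(k.islower() for k in src_vocab)

print(detect_do_lower_case({"hello</w>": 4, "world</w>": 5}))  # True
print(detect_do_lower_case({"Hello</w>": 4, "world</w>": 5}))  # False
```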
@@ -154,7 +154,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
             File containing the vocabulary for the target language.
         merges_file (:obj:`str`):
             File containing the merges.
-        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to lowercase the input when tokenizing.
         unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
@@ -186,6 +186,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
         src_vocab_file=None,
         tgt_vocab_file=None,
         merges_file=None,
+        do_lower_case=False,
         unk_token="<unk>",
         bos_token="<s>",
         sep_token="</s>",
@@ -197,6 +198,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
             src_vocab_file=src_vocab_file,
             tgt_vocab_file=tgt_vocab_file,
             merges_file=merges_file,
+            do_lower_case=do_lower_case,
             unk_token=unk_token,
             bos_token=bos_token,
             sep_token=sep_token,
@@ -207,6 +209,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
         self.src_vocab_file = src_vocab_file
         self.tgt_vocab_file = tgt_vocab_file
         self.merges_file = merges_file
+        self.do_lower_case = do_lower_case
 
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
@@ -351,6 +354,9 @@ class FSMTTokenizer(PreTrainedTokenizer):
         #     raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
         lang = self.src_lang
 
+        if self.do_lower_case:
+            text = text.lower()
+
         if bypass_tokenizer:
             text = text.split()
         else:
......
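Note where the lowering sits in `_tokenize`: before both the Moses pipeline and BPE, so the merges only ever see lowercased text. A rough sketch of the resulting flow; the method names `moses_pipeline`, `moses_tokenize`, and `bpe` approximate the real implementation, which this diff only excerpts:

```python
def _tokenize_sketch(self, text, bypass_tokenizer=False):
    # Approximate flow of FSMTTokenizer._tokenize after this change.
    if self.do_lower_case:
        text = text.lower()  # new: lowercase first, before any tokenization
    if bypass_tokenizer:
        text = text.split()  # caller supplies pre-tokenized text
    else:
        text = self.moses_pipeline(text, lang=self.src_lang)   # assumed helper
        text = self.moses_tokenize(text, lang=self.src_lang)   # assumed helper
    split_tokens = []
    for token in text:
        split_tokens.extend(self.bpe(token).split(" "))  # BPE sees lowercased tokens
    return split_tokens
```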
@@ -151,6 +151,13 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
         self.assertEqual(decoded_text, src_text)
 
+    @slow
+    def test_tokenizer_lower(self):
+        tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-ru-en", do_lower_case=True)
+        tokens = tokenizer.tokenize("USA is United States of America")
+        expected = ["us", "a</w>", "is</w>", "un", "i", "ted</w>", "st", "ates</w>", "of</w>", "am", "er", "ica</w>"]
+        self.assertListEqual(tokens, expected)
+
     @unittest.skip("FSMTConfig.__init__ requires non-optional args")
     def test_torch_encode_plus_sent_to_model(self):
         pass
......
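Because the converter now writes `do_lower_case` into the tokenizer config, checkpoints converted after this commit pick the flag up automatically on `from_pretrained`, with no per-call argument needed. A sketch of what a generated `tokenizer_config.json` might contain, with illustrative values only:

```python
import json

# Illustrative contents; the real values come from the checkpoint being converted.
tokenizer_conf = {
    "langs": ["ru", "en"],       # illustrative language pair
    "model_max_length": 1024,
    "do_lower_case": True,       # detected from the source vocab
}
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=2))
```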