Unverified Commit ab108a0e authored by Guillaume Klein, committed by GitHub

Add missing lang tokens in M2M100Tokenizer.get_vocab (#18416)

parent 0bd6d934
@@ -280,7 +280,7 @@ class M2M100Tokenizer(PreTrainedTokenizer):
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
 
     def get_vocab(self) -> Dict:
-        vocab = self.encoder.copy()
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
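For context, a minimal usage sketch (not part of the commit) of what this fix changes: before the patch, get_vocab returned only self.encoder, which holds the base SentencePiece vocabulary, so language tokens such as "__en__" were missing; routing every id through convert_ids_to_tokens also covers the lang-token ids appended after the base vocabulary. The checkpoint name below is an assumption, chosen as a common public M2M100 checkpoint.

# Sketch assuming the public "facebook/m2m100_418M" checkpoint.
from transformers import M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
vocab = tokenizer.get_vocab()

# With this fix, language tokens appear in the vocab alongside subwords,
# and their ids agree with get_lang_id.
lang_token = tokenizer.get_lang_token("en")  # "__en__"
assert lang_token in vocab
assert vocab[lang_token] == tokenizer.get_lang_id("en")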
@@ -89,7 +89,7 @@ class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(vocab_keys[0], "</s>")
         self.assertEqual(vocab_keys[1], "<unk>")
         self.assertEqual(vocab_keys[-1], "<s>")
-        self.assertEqual(len(vocab_keys), 10)
+        self.assertEqual(len(vocab_keys), 110)
 
     def test_vocab_size(self):
         self.assertEqual(self.get_tokenizer().vocab_size, 117)
@@ -160,6 +160,9 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.get_lang_id("ro"), 128076)
         self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063)
 
+    def test_get_vocab(self):
+        self.assertIn(self.tokenizer.get_lang_token("en"), self.tokenizer.get_vocab())
+
     def test_tokenizer_batch_encode_plus(self):
         self.tokenizer.src_lang = "en"
         ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
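A quick way to reproduce the new test_get_vocab check locally, extended to a few of the language codes the integration tests touch (the checkpoint name is again an assumption):

# Mirrors the new test_get_vocab assertion above.
from transformers import M2M100Tokenizer

tok = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
for lang in ("en", "fr", "ro", "mr"):
    # Each language token (e.g. "__fr__") should now be listed in the vocab.
    assert tok.get_lang_token(lang) in tok.get_vocab()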