Unverified Commit dbf7bfaf authored by Joel Tang's avatar Joel Tang Committed by GitHub
Browse files

Fix idx2sym not loaded from pretrained vocab file in Transformer XL (#27589)

* Load idx2sym from pretrained vocab file in Transformer XL

When loading vocab file from a pretrained tokenizer for Transformer XL,
although the pickled vocabulary file contains an idx2sym key, it wasn't
being loaded: it was discarded because an empty list already existed as
an attribute with that name.

Solution is to explicitly take it into account, just like for sym2idx.

* ran make style
parent dc68a39c
......@@ -223,7 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_dict is not None:
for key, value in vocab_dict.items():
if key not in self.__dict__ or key == "sym2idx":
if key not in self.__dict__ or key in ["sym2idx", "idx2sym"]:
self.__dict__[key] = value
elif vocab_file is not None:
self.build_vocab()
......
......@@ -15,7 +15,9 @@
import os
import pickle
import unittest
from collections import Counter, OrderedDict
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
......@@ -47,6 +49,25 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
saved_dict = {
"eos_idx": 0,
"min_freq": 0,
"vocab_file": None,
"counter": Counter(["welcome home"]),
"sym2idx": OrderedDict([("<eos>", 0), ("welcome", 1), ("home", 2)]),
"delimiter": None,
"idx2sym": ["<eos>", "welcome", "home"],
"max_size": None,
"lower_case": False,
"special": ["<eos>"],
}
self.pretrained_vocab_file = os.path.join(
self.tmpdirname, "mock_folder", VOCAB_FILES_NAMES["pretrained_vocab_file"]
)
os.makedirs(os.path.dirname(self.pretrained_vocab_file), exist_ok=True)
with open(self.pretrained_vocab_file, "wb") as f:
pickle.dump(saved_dict, f)
def get_tokenizer(self, **kwargs):
    """Return a TransfoXLTokenizer loaded from the temp-dir fixture.

    ``lower_case`` is always forced to ``True``, overriding any value the
    caller passed, so every test runs against a lower-casing tokenizer.
    """
    return TransfoXLTokenizer.from_pretrained(
        self.tmpdirname, **{**kwargs, "lower_case": True}
    )
......@@ -128,3 +149,8 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Check that token is moved to specified id
self.assertEqual(tokenizer.encode("new1"), [1])
self.assertEqual(tokenizer.decode([1]), "new1")
def test_from_pretrained_vocab_file(self):
    """Round-trip a sentence through a tokenizer restored from a pickled vocab.

    Loads the mock pretrained vocab file written in ``setUp`` and checks
    that encode followed by decode reproduces the input, which requires
    ``idx2sym`` to have been loaded from the pickle.
    """
    vocab_dir = os.path.join(self.tmpdirname, "mock_folder")
    tokenizer = TransfoXLTokenizer.from_pretrained(vocab_dir)
    text = "welcome home"
    round_tripped = tokenizer.decode(tokenizer.encode(text))
    self.assertEqual(round_tripped, text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment