"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "06d468d3f07955e4d6ff82feda17120e0e17d987"
Unverified Commit dbf7bfaf authored by Joel Tang's avatar Joel Tang Committed by GitHub
Browse files

Fix idx2sym not loaded from pretrained vocab file in Transformer XL (#27589)

* Load idx2sym from pretrained vocab file in Transformer XL

When loading a vocab file from a pretrained tokenizer for Transformer XL,
the pickled vocabulary file contains an idx2sym key, but its value is never
loaded: it is discarded because the idx2sym attribute already exists on the
tokenizer as an empty list.

Solution is to explicitly take it into account, just like for sym2idx.

* ran make style
parent dc68a39c
......@@ -223,7 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_dict is not None:
for key, value in vocab_dict.items():
if key not in self.__dict__ or key == "sym2idx":
if key not in self.__dict__ or key in ["sym2idx", "idx2sym"]:
self.__dict__[key] = value
elif vocab_file is not None:
self.build_vocab()
......
......@@ -15,7 +15,9 @@
import os
import pickle
import unittest
from collections import Counter, OrderedDict
from transformers.models.transfo_xl.tokenization_transfo_xl import VOCAB_FILES_NAMES, TransfoXLTokenizer
......@@ -47,6 +49,25 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
saved_dict = {
"eos_idx": 0,
"min_freq": 0,
"vocab_file": None,
"counter": Counter(["welcome home"]),
"sym2idx": OrderedDict([("<eos>", 0), ("welcome", 1), ("home", 2)]),
"delimiter": None,
"idx2sym": ["<eos>", "welcome", "home"],
"max_size": None,
"lower_case": False,
"special": ["<eos>"],
}
self.pretrained_vocab_file = os.path.join(
self.tmpdirname, "mock_folder", VOCAB_FILES_NAMES["pretrained_vocab_file"]
)
os.makedirs(os.path.dirname(self.pretrained_vocab_file), exist_ok=True)
with open(self.pretrained_vocab_file, "wb") as f:
pickle.dump(saved_dict, f)
def get_tokenizer(self, **kwargs):
    """Instantiate a tokenizer from the fixture directory, always lower-casing."""
    # Tests in this suite assume case-insensitive behavior, so force it on.
    kwargs.update(lower_case=True)
    return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
......@@ -128,3 +149,8 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Check that token is moved to specified id
self.assertEqual(tokenizer.encode("new1"), [1])
self.assertEqual(tokenizer.decode([1]), "new1")
def test_from_pretrained_vocab_file(self):
    """Loading the pickled pretrained vocab file must give a tokenizer that round-trips text."""
    vocab_dir = os.path.join(self.tmpdirname, "mock_folder")
    tokenizer = TransfoXLTokenizer.from_pretrained(vocab_dir)
    sentence = "welcome home"
    # encode -> decode must reproduce the input, proving idx2sym was restored.
    self.assertEqual(tokenizer.decode(tokenizer.encode(sentence)), sentence)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment