"Plugson/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "33b958e11280f632c5b8007ca8b50f97fe9f9551"
Commit 7044ed6b authored by thomwolf

fix tokenizers serialization

parent cd65c41a
@@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):

     tokenizer_class = DistilBertTokenizer

-    def get_tokenizer(self):
-        return DistilBertTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def test_sequence_builders(self):
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
...
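With get_tokenizer now forwarding **kwargs, the shared test machinery can construct the DistilBERT tokenizer with overridden construction-time parameters. A minimal sketch of a test that could use this hook (the test name and the max_len value are illustrative, not part of the diff):

    def test_tokenizer_kwargs_passthrough(self):
        # kwargs given here reach DistilBertTokenizer.from_pretrained above,
        # so construction-time options can be varied per test.
        tokenizer = self.get_tokenizer(max_len=20)
        self.assertEqual(tokenizer.max_len, 20)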
@@ -67,13 +67,13 @@ class CommonTestCases:

         with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
-            tokenizer = tokenizer.from_pretrained(tmpdirname)
+            tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)

             after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
             self.assertListEqual(before_tokens, after_tokens)

             self.assertEqual(tokenizer.max_len, 42)
-            tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43)
+            tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
             self.assertEqual(tokenizer.max_len, 43)

     def test_pickle_tokenizer(self):
...
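from_pretrained is a classmethod, so reloading through self.tokenizer_class makes explicit which class rebuilds the tokenizer when subclasses run this shared test. Outside the harness, the round trip the test exercises looks roughly like this sketch (the model name and sample text are illustrative; the assertion assumes, as the test does, that the saved files fully determine the reloaded tokenizer):

    import tempfile
    from pytorch_transformers import BertTokenizer  # any concrete tokenizer subclass

    text = u"He is very happy, UNwant\u00E9d,running"
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    before = tokenizer.encode(text)

    with tempfile.TemporaryDirectory() as tmpdir:
        tokenizer.save_pretrained(tmpdir)          # writes vocab and config files
        reloaded = BertTokenizer.from_pretrained(tmpdir)
        assert reloaded.encode(text) == before     # round trip preserves encoding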
@@ -95,6 +95,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             # in a library like ours, at all.
             vocab_dict = torch.load(pretrained_vocab_file)
             for key, value in vocab_dict.items():
-                self.__dict__[key] = value
+                if key not in self.__dict__:
+                    self.__dict__[key] = value

         if vocab_file is not None:
...
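The added guard keeps attributes restored from the pickled vocab file from overwriting values that __init__ has already set, for example from kwargs passed through from_pretrained, which is the serialization bug this commit fixes. A standalone illustration of the pattern (class and attribute names are hypothetical, not library code):

    class Holder(object):
        def __init__(self):
            self.max_len = 43        # set explicitly, e.g. from an init kwarg

    loaded = {"max_len": 42, "counter": {"the": 10}}   # state restored from disk

    h = Holder()
    for key, value in loaded.items():
        if key not in h.__dict__:    # same guard the diff adds
            h.__dict__[key] = value

    assert h.max_len == 43           # the explicitly set value wins
    assert h.counter == {"the": 10}  # missing attributes are still restored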
@@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, max_len=None,
+    def __init__(self, vocab_file,
                  do_lower_case=False, remove_space=True, keep_accents=False,
                  bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
                  pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
...
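Dropping max_len from the explicit signature lets it travel through **kwargs to the PreTrainedTokenizer base __init__, as in the other tokenizers. A toy sketch of that pattern (class names are hypothetical stand-ins, not the library's real classes):

    # Construction kwargs the subclass does not name fall through **kwargs
    # to the base class, which owns attributes such as max_len.
    class PreTrainedTokenizerSketch(object):
        def __init__(self, max_len=None, **kwargs):
            self.max_len = max_len if max_len is not None else int(1e12)

    class XLNetTokenizerSketch(PreTrainedTokenizerSketch):
        def __init__(self, vocab_file, do_lower_case=False, **kwargs):
            super(XLNetTokenizerSketch, self).__init__(**kwargs)
            self.vocab_file = vocab_file
            self.do_lower_case = do_lower_case

    tok = XLNetTokenizerSketch("spiece.model", max_len=512)
    assert tok.max_len == 512    # handled by the base class, not the subclass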