Unverified Commit d490b5d5 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Fast Tokenizers save pretrained should return the list of generated file paths. (#2918)



* Correctly return the tuple of generated file(s) when calling save_pretrained
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>

* Quality and format.
Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>
parent 2708b44e
......@@ -1822,4 +1822,5 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
folder, file = save_directory, self.vocab_files_names["vocab_file"]
else:
folder, file = os.path.split(os.path.abspath(save_directory))
self._tokenizer.save(folder, file)
return tuple(self._tokenizer.save(folder, file))
......@@ -236,6 +236,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
@require_torch
def test_transfoxl(self):
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
......@@ -272,6 +275,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_distilbert(self):
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
......@@ -308,6 +314,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_gpt2(self):
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
......@@ -343,6 +352,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_roberta(self):
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
......@@ -378,6 +390,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_openai(self):
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
......@@ -413,6 +428,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment