"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a316a6aaa8fcfe9ff0004b122078313f0eae0631"
Unverified commit d490b5d5 authored by Funtowicz Morgan, committed by GitHub

Fast Tokenizers save pretrained should return the list of generated file paths. (#2918)



* Correctly return the tuple of generated file(s) when calling save_pretrained
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Quality and format.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 2708b44e
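The effect of the change below is that a fast tokenizer's save methods now report which files they wrote, matching the slow (pure-Python) tokenizers. A minimal usage sketch, assuming a GPT-2 fast tokenizer (the model identifier and the example paths are illustrative, not part of this commit):

```python
from transformers import GPT2TokenizerFast

# Load a fast (Rust-backed) tokenizer.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# After this change, saving the vocabulary returns the paths of the
# generated files instead of None, e.g. ("./vocab.json", "./merges.txt").
saved_files = tokenizer.save_vocabulary(".")
print(saved_files)
```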
@@ -1822,4 +1822,5 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
             folder, file = save_directory, self.vocab_files_names["vocab_file"]
         else:
             folder, file = os.path.split(os.path.abspath(save_directory))
-        self._tokenizer.save(folder, file)
+        return tuple(self._tokenizer.save(folder, file))
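The hunk above is the substance of the fix: the underlying `tokenizers` serializer already returns the list of files it writes, but the previous code discarded that value. Wrapping it in `tuple(...)` and returning it aligns the fast tokenizers with the slow ones, whose `save_vocabulary` returns a tuple of file paths, so callers can handle both implementations uniformly.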
@@ -236,6 +236,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check alignment for build_inputs_with_special_tokens
             self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
 
+            # Check the number of returned files for save_vocabulary
+            self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
+
     @require_torch
     def test_transfoxl(self):
         for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
@@ -272,6 +275,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check alignment for build_inputs_with_special_tokens
             self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
 
+            # Check the number of returned files for save_vocabulary
+            self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
+
     def test_distilbert(self):
         for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
@@ -308,6 +314,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check alignment for build_inputs_with_special_tokens
             self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
 
+            # Check the number of returned files for save_vocabulary
+            self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
+
     def test_gpt2(self):
         for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
@@ -343,6 +352,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check alignment for build_inputs_with_special_tokens
             self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
 
+            # Check the number of returned files for save_vocabulary
+            self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
+
     def test_roberta(self):
         for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
@@ -378,6 +390,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check alignment for build_inputs_with_special_tokens
             self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
 
+            # Check the number of returned files for save_vocabulary
+            self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
+
     def test_openai(self):
         for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
@@ -413,6 +428,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check alignment for build_inputs_with_special_tokens
             self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
 
+            # Check the number of returned files for save_vocabulary
+            self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
+
 if __name__ == "__main__":
     unittest.main()
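Each updated test adds the same assertion: for every pretrained checkpoint it iterates over, the fast tokenizer (`tokenizer_r`) must report the same number of saved vocabulary files as its slow counterpart (`tokenizer_p`). Before this commit the fast implementation returned None, so `len(...)` on its result would fail.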