Commit d3418a94 authored by thomwolf's avatar thomwolf
Browse files

update tests

parent 56e98ba8
...@@ -16,15 +16,12 @@ from __future__ import absolute_import ...@@ -16,15 +16,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import copy
import os import os
import shutil
import json import json
import random import tempfile
import uuid
import unittest import unittest
import logging from .tokenization_tests_commons import TemporaryDirectory
class ConfigTester(object): class ConfigTester(object):
...@@ -48,16 +45,28 @@ class ConfigTester(object): ...@@ -48,16 +45,28 @@ class ConfigTester(object):
def create_and_test_config_to_json_file(self): def create_and_test_config_to_json_file(self):
config_first = self.config_class(**self.inputs_dict) config_first = self.config_class(**self.inputs_dict)
json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
with TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "config.json")
config_first.to_json_file(json_file_path) config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path) config_second = self.config_class.from_json_file(json_file_path)
os.remove(json_file_path)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_from_and_save_pretrained(self):
config_first = self.config_class(**self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
config_first.save_pretrained(tmpdirname)
config_second = self.config_class.from_pretrained(tmpdirname)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_common_tests(self): def run_common_tests(self):
self.create_and_test_config_common_properties() self.create_and_test_config_common_properties()
self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file() self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
\ No newline at end of file
...@@ -15,10 +15,7 @@ ...@@ -15,10 +15,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import os import os
import sys
import json import json
import tempfile
import shutil
import unittest import unittest
from transformers.model_card import ModelCard from transformers.model_card import ModelCard
...@@ -50,10 +47,6 @@ class ModelCardTester(unittest.TestCase): ...@@ -50,10 +47,6 @@ class ModelCardTester(unittest.TestCase):
'ROUGE-1': 76, 'ROUGE-1': 76,
}, },
} }
self.tmpdirname = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_model_card_common_properties(self): def test_model_card_common_properties(self):
model_card = ModelCard.from_dict(self.inputs_dict) model_card = ModelCard.from_dict(self.inputs_dict)
...@@ -83,5 +76,14 @@ class ModelCardTester(unittest.TestCase): ...@@ -83,5 +76,14 @@ class ModelCardTester(unittest.TestCase):
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
def test_model_card_from_and_save_pretrained(self):
model_card_first = ModelCard.from_dict(self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
model_card_first.save_pretrained(tmpdirname)
model_card_second = ModelCard.from_pretrained(tmpdirname)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment