"vscode:/vscode.git/clone" did not exist on "d44db1145cc87f6092a8701ff6b9c6a18077e292"
Commit d3418a94 authored by thomwolf's avatar thomwolf
Browse files

update tests

parent 56e98ba8
......@@ -16,15 +16,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import shutil
import json
import random
import uuid
import tempfile
import unittest
import logging
from .tokenization_tests_commons import TemporaryDirectory
class ConfigTester(object):
......@@ -48,16 +45,28 @@ class ConfigTester(object):
def create_and_test_config_to_json_file(self):
config_first = self.config_class(**self.inputs_dict)
json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
with TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "config.json")
config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path)
os.remove(json_file_path)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_from_and_save_pretrained(self):
config_first = self.config_class(**self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
config_first.save_pretrained(tmpdirname)
config_second = self.config_class.from_pretrained(tmpdirname)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_common_tests(self):
self.create_and_test_config_common_properties()
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
......@@ -15,10 +15,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import sys
import json
import tempfile
import shutil
import unittest
from transformers.model_card import ModelCard
......@@ -50,10 +47,6 @@ class ModelCardTester(unittest.TestCase):
'ROUGE-1': 76,
},
}
self.tmpdirname = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_model_card_common_properties(self):
model_card = ModelCard.from_dict(self.inputs_dict)
......@@ -83,5 +76,14 @@ class ModelCardTester(unittest.TestCase):
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
def test_model_card_from_and_save_pretrained(self):
model_card_first = ModelCard.from_dict(self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
model_card_first.save_pretrained(tmpdirname)
model_card_second = ModelCard.from_pretrained(tmpdirname)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment