test_configuration_common.py 6.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
import copy
17
import json
Aymeric Augustin's avatar
Aymeric Augustin committed
18
import os
19
import tempfile
Sylvain Gugger's avatar
Sylvain Gugger committed
20

21
from transformers import is_torch_available
22

23
from .utils.test_configuration_utils import config_common_kwargs
24
25


26
class ConfigTester(object):
    """Shared checks run against a model configuration class.

    Every check reports through the ``assert*`` methods of ``parent`` and builds
    its reference config as ``config_class(**kwargs)``.

    Args:
        parent: Test-case object providing ``assertTrue``, ``assertEqual``,
            ``assertRaises`` and ``assertIsNotNone``.
        config_class: Configuration class under test.
        has_text_modality: When true, ``vocab_size`` is also checked as a common property.
        common_properties: Attribute names the config must expose. When ``None``,
            ``["hidden_size", "num_attention_heads", "num_hidden_layers"]`` is used.
            The given sequence is only read, never modified.
        **kwargs: Keyword arguments used to construct the reference config.
    """

    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.has_text_modality = has_text_modality
        self.inputs_dict = kwargs
        self.common_properties = common_properties

    def create_and_test_config_common_properties(self):
        """Check that common properties can be read, assigned, and passed to the constructor.

        For the assignment and constructor checks, a property that raises
        ``NotImplementedError`` is skipped rather than failing the test.
        """
        config = self.config_class(**self.inputs_dict)

        # Always work on a fresh list. Extending ``self.common_properties`` in place
        # would append "vocab_size" again on every call and corrupt the caller's list
        # (and would fail outright if a tuple was passed).
        common_properties = list(
            ["hidden_size", "num_attention_heads", "num_hidden_layers"]
            if self.common_properties is None
            else self.common_properties
        )

        # Add common fields for text models
        if self.has_text_modality:
            common_properties.append("vocab_size")

        # Test that config has the common properties as getters
        for prop in common_properties:
            self.parent.assertTrue(hasattr(config, prop), msg=f"`{prop}` does not exist")

        # Test that config has the common properties as setter
        for idx, name in enumerate(common_properties):
            try:
                setattr(config, name, idx)
                self.parent.assertEqual(
                    getattr(config, name), idx, msg=f"`{name}` value {idx} expected, but was {getattr(config, name)}"
                )
            except NotImplementedError:
                # Some models might not be able to implement setters for common_properties
                # In that case, a NotImplementedError is raised
                pass

        # Test if config class can be called with Config(prop_name=..)
        for idx, name in enumerate(common_properties):
            try:
                config = self.config_class(**{name: idx})
                self.parent.assertEqual(
                    getattr(config, name), idx, msg=f"`{name}` value {idx} expected, but was {getattr(config, name)}"
                )
            except NotImplementedError:
                # Some models might not be able to implement setters for common_properties
                # In that case, a NotImplementedError is raised
                pass

    def create_and_test_config_to_json_string(self):
        """Check that every constructor kwarg appears unchanged in ``to_json_string()`` output."""
        config = self.config_class(**self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        """Round-trip the config through ``to_json_file`` / ``from_json_file`` and compare ``to_dict()``."""
        config_first = self.config_class(**self.inputs_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "config.json")
            config_first.to_json_file(json_file_path)
            config_second = self.config_class.from_json_file(json_file_path)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def create_and_test_config_from_and_save_pretrained(self):
        """Round-trip via ``save_pretrained`` / ``from_pretrained``; loading a missing path must raise ``OSError``."""
        config_first = self.config_class(**self.inputs_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            config_first.save_pretrained(tmpdirname)
            config_second = self.config_class.from_pretrained(tmpdirname)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

        # Prefixing the (already deleted) temp dir path with "." yields a path that is
        # expected not to exist.
        # NOTE(review): relies on tmpdirname being an absolute path — confirm on all platforms.
        with self.parent.assertRaises(OSError):
            self.config_class.from_pretrained(f".{tmpdirname}")

    def create_and_test_config_from_and_save_pretrained_subfolder(self):
        """Save into ``<dir>/test`` and check ``from_pretrained(dir, subfolder="test")`` restores the same config."""
        config_first = self.config_class(**self.inputs_dict)

        subfolder = "test"
        with tempfile.TemporaryDirectory() as tmpdirname:
            sub_tmpdirname = os.path.join(tmpdirname, subfolder)
            config_first.save_pretrained(sub_tmpdirname)
            config_second = self.config_class.from_pretrained(tmpdirname, subfolder=subfolder)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def create_and_test_config_with_num_labels(self):
        """Check that ``id2label`` / ``label2id`` follow ``num_labels`` at construction and on reassignment."""
        config = self.config_class(**self.inputs_dict, num_labels=5)
        self.parent.assertEqual(len(config.id2label), 5)
        self.parent.assertEqual(len(config.label2id), 5)

        config.num_labels = 3
        self.parent.assertEqual(len(config.id2label), 3)
        self.parent.assertEqual(len(config.label2id), 3)

    def check_config_can_be_init_without_params(self):
        """Check no-argument construction.

        Composition configs (``is_composition`` true) must raise ``ValueError``;
        all other configs must construct successfully.
        """
        if self.config_class.is_composition:
            with self.parent.assertRaises(ValueError):
                self.config_class()
        else:
            config = self.config_class()
            self.parent.assertIsNotNone(config)

    def check_config_arguments_init(self):
        """Check that every entry of ``config_common_kwargs`` is stored on a freshly built config.

        Raises:
            ValueError: Lists every key whose stored value differs from the expected one.
        """
        # Pass a copy so the constructor cannot alter the reference dict that is
        # compared against below.
        kwargs = copy.deepcopy(config_common_kwargs)
        config = self.config_class(**kwargs)
        wrong_values = []
        for key, value in config_common_kwargs.items():
            if key == "torch_dtype":
                # The expected value is a torch dtype object, so this check needs torch.
                if not is_torch_available():
                    continue
                import torch

                # NOTE(review): assumes config_common_kwargs sets torch_dtype to a float16 value — confirm.
                if config.torch_dtype != torch.float16:
                    wrong_values.append(("torch_dtype", config.torch_dtype, torch.float16))
            elif getattr(config, key) != value:
                wrong_values.append((key, getattr(config, key), value))

        if wrong_values:
            errors = "\n".join(f"- {key}: got {got} instead of {expected}" for key, got, expected in wrong_values)
            raise ValueError(f"The following keys were not properly set in the config:\n{errors}")

    def run_common_tests(self):
        """Run every shared configuration check in sequence."""
        self.create_and_test_config_common_properties()
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()
        self.create_and_test_config_from_and_save_pretrained()
        self.create_and_test_config_from_and_save_pretrained_subfolder()
        self.create_and_test_config_with_num_labels()
        self.check_config_can_be_init_without_params()
        self.check_config_arguments_init()