"...lm-evaluation-harness.git" did not exist on "90d818daa915199d32e833975d60671ff4b5b451"
test_configuration_common.py 7.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17

import json
Aymeric Augustin's avatar
Aymeric Augustin committed
18
import os
19
import tempfile
Sylvain Gugger's avatar
Sylvain Gugger committed
20
21
22
23
import unittest

from huggingface_hub import HfApi
from requests.exceptions import HTTPError
24
from transformers import BertConfig, GPT2Config
Sylvain Gugger's avatar
Sylvain Gugger committed
25
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
26
27
28


class ConfigTester:
    """Shared checks that a model configuration class is expected to pass.

    Failures are reported through the assertion methods of ``parent``
    (``assertTrue``, ``assertEqual``, ``assertIsNotNone``), typically a
    ``unittest.TestCase``, so they surface in the calling test.
    """

    def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
        """
        Args:
            parent: object providing ``assertTrue`` / ``assertEqual`` / ``assertIsNotNone``.
            config_class: the configuration class under test.
            has_text_modality: when True, ``vocab_size`` is also checked as a common property.
            **kwargs: keyword arguments used to instantiate ``config_class`` in most checks.
        """
        self.parent = parent
        self.config_class = config_class
        self.has_text_modality = has_text_modality
        self.inputs_dict = kwargs

    def _check_property_value(self, config, name, expected):
        """Assert that ``config.<name>`` equals ``expected`` with a readable failure message."""
        actual = getattr(config, name)
        self.parent.assertEqual(actual, expected, msg=f"`{name}` value {expected} expected, but was {actual}")

    def create_and_test_config_common_properties(self):
        """Check that the common properties exist, can be set, and are accepted by the constructor.

        Models that cannot support setting one of these properties signal it by
        raising ``NotImplementedError``, which is tolerated for that property.
        """
        config = self.config_class(**self.inputs_dict)
        common_properties = ["hidden_size", "num_attention_heads", "num_hidden_layers"]

        # Add common fields for text models
        if self.has_text_modality:
            common_properties.append("vocab_size")

        # Test that config has the common properties as getters
        for prop in common_properties:
            self.parent.assertTrue(hasattr(config, prop), msg=f"`{prop}` does not exist")

        # Test that config has the common properties as setters
        for idx, name in enumerate(common_properties):
            try:
                setattr(config, name, idx)
                self._check_property_value(config, name, idx)
            except NotImplementedError:
                # Some models might not be able to implement setters for common_properties;
                # in that case a NotImplementedError is raised and the property is skipped.
                pass

        # Test if config class can be called with Config(prop_name=..)
        for idx, name in enumerate(common_properties):
            try:
                config_from_kwarg = self.config_class(**{name: idx})
                self._check_property_value(config_from_kwarg, name, idx)
            except NotImplementedError:
                # Same opt-out as above, for constructor keyword arguments.
                pass

    def create_and_test_config_to_json_string(self):
        """Check every constructor kwarg appears with the same value in ``to_json_string()``."""
        config = self.config_class(**self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        """Check a ``to_json_file`` / ``from_json_file`` round trip preserves ``to_dict()``."""
        config_first = self.config_class(**self.inputs_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "config.json")
            config_first.to_json_file(json_file_path)
            config_second = self.config_class.from_json_file(json_file_path)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def create_and_test_config_from_and_save_pretrained(self):
        """Check a ``save_pretrained`` / ``from_pretrained`` round trip preserves ``to_dict()``."""
        config_first = self.config_class(**self.inputs_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            config_first.save_pretrained(tmpdirname)
            config_second = self.config_class.from_pretrained(tmpdirname)

        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def create_and_test_config_with_num_labels(self):
        """Check ``id2label`` / ``label2id`` follow ``num_labels`` at init time and on reassignment."""
        config = self.config_class(**self.inputs_dict, num_labels=5)
        self.parent.assertEqual(len(config.id2label), 5)
        self.parent.assertEqual(len(config.label2id), 5)

        config.num_labels = 3
        self.parent.assertEqual(len(config.id2label), 3)
        self.parent.assertEqual(len(config.label2id), 3)

    def check_config_can_be_init_without_params(self):
        """Check the config can be built with no arguments (skipped for composition configs)."""
        if self.config_class.is_composition:
            return
        config = self.config_class()
        self.parent.assertIsNotNone(config)

    def run_common_tests(self):
        """Run every shared configuration check in a fixed order."""
        self.create_and_test_config_common_properties()
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()
        self.create_and_test_config_from_and_save_pretrained()
        self.create_and_test_config_with_num_labels()
        self.check_config_can_be_init_without_params()
Sylvain Gugger's avatar
Sylvain Gugger committed
118
119
120
121
122
123
124
125
126
127
128
129


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
    """Push a small ``BertConfig`` to the staging Hub and reload it, for a user and an organization repo."""

    @classmethod
    def setUpClass(cls):
        cls._api = HfApi(endpoint=ENDPOINT_STAGING)
        cls._token = cls._api.login(username=USER, password=PASS)

    @classmethod
    def tearDownClass(cls):
        # Each repository is deleted in its own try block so that an HTTPError on
        # one (e.g. it was never created) does not skip cleanup of the other.
        repos_to_delete = (
            {"name": "test-config"},
            {"name": "test-config-org", "organization": "valid_org"},
        )
        for repo_kwargs in repos_to_delete:
            try:
                cls._api.delete_repo(token=cls._token, **repo_kwargs)
            except HTTPError:
                pass

    @staticmethod
    def _build_small_bert_config():
        """Return the small ``BertConfig`` pushed by both tests."""
        return BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )

    def _assert_same_attributes(self, expected_config, loaded_config):
        """Compare every instance attribute of ``expected_config`` except ``transformers_version``."""
        for attr_name, attr_value in expected_config.__dict__.items():
            if attr_name == "transformers_version":
                continue
            self.assertEqual(attr_value, getattr(loaded_config, attr_name))

    def test_push_to_hub(self):
        config = self._build_small_bert_config()
        with tempfile.TemporaryDirectory() as tmp_dir:
            save_dir = os.path.join(tmp_dir, "test-config")
            config.save_pretrained(save_dir, push_to_hub=True, use_auth_token=self._token)

            reloaded = BertConfig.from_pretrained(f"{USER}/test-config")
            self._assert_same_attributes(config, reloaded)

    def test_push_to_hub_in_organization(self):
        config = self._build_small_bert_config()

        with tempfile.TemporaryDirectory() as tmp_dir:
            save_dir = os.path.join(tmp_dir, "test-config-org")
            config.save_pretrained(
                save_dir,
                push_to_hub=True,
                use_auth_token=self._token,
                organization="valid_org",
            )

            reloaded = BertConfig.from_pretrained("valid_org/test-config-org")
            self._assert_same_attributes(config, reloaded)
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185


class ConfigTestUtils(unittest.TestCase):
    """Tests for configuration helper methods."""

    def test_config_from_string(self):
        """``update_from_string`` must correctly parse int, float, bool and str values."""
        config = GPT2Config()

        # One perturbed value per supported type; each differs from the default.
        expected = {
            "n_embd": config.n_embd + 1,  # int
            "resid_pdrop": config.resid_pdrop + 1.0,  # float
            "scale_attn_weights": not config.scale_attn_weights,  # bool
            "summary_type": config.summary_type + "foo",  # str
        }
        update_string = ",".join(f"{key}={value}" for key, value in expected.items())
        config.update_from_string(update_string)

        for key, value in expected.items():
            self.assertEqual(value, getattr(config, key), f"mismatch for key: {key}")