# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import logging
import os
from typing import Dict, Tuple

from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url


logger = logging.getLogger(__name__)


class PretrainedConfig(object):
    r""" Base class for all configuration classes.
        Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.

        Note:
            A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
            It only affects the model's configuration.

        Class attributes (overridden by derived classes):
            - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.

        Args:
            finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
                Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
            num_labels (:obj:`int`, `optional`, defaults to :obj:`2`):
                Number of classes to use when the model is a classification model (sequences/tokens).
            output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the model should return all hidden states.
            torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the model is used with TorchScript (for PyTorch models).
    """
    model_type: str = ""

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            # Keys are always strings in JSON so convert ids to int here.
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise err

    @property
    def num_labels(self):
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels):
        self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
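        # Illustrative note: setting num_labels = 3 yields
        # id2label == {0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"} and
        # label2id == {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}.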

    def save_pretrained(self, save_directory):
        """
        Save a configuration object to the directory `save_directory`, so that it
        can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.

        Args:
            save_directory (:obj:`string`):
                Directory where the configuration JSON file will be saved.
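
        Example (a minimal sketch; assumes a derived class such as :class:`BertConfig` and that the directory already exists)::

            config = BertConfig.from_pretrained('bert-base-uncased')
            config.save_pretrained('./my_model_directory/')
            # the configuration can then be re-loaded with:
            config = BertConfig.from_pretrained('./my_model_directory/')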
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info("Configuration saved in {}".format(output_config_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
        r"""

        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.

        Args:
            pretrained_model_name_or_path (:obj:`string`):
                either:
                  - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
                    download, e.g.: ``bert-base-uncased``.
                  - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
                    our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                  - a path to a `directory` containing a configuration file saved using the
                    :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
                  - a path or url to a saved configuration JSON `file`, e.g.:
                    ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`string`, `optional`):
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            kwargs (:obj:`Dict[str, any]`, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
                controlled by the `return_unused_kwargs` keyword parameter.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
            proxies (:obj:`Dict`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.:
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
                The proxies are used on each request.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If False, then this function returns just the final configuration object.
                If True, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e. the part
                of kwargs which has not been used to update `config` and is otherwise ignored.

        Returns:
            :class:`PretrainedConfig`: An instance of a configuration object

        Examples::

            # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            assert config.output_attentions == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attentions == True
            assert unused_kwargs == {'foo': False}

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
        for instantiating a Config using `from_dict`.

        Parameters:
            pretrained_model_name_or_path (:obj:`string`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object, together with the remaining keyword arguments.
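
        Example (illustrative; assumes the ``bert-base-uncased`` checkpoint is reachable or cached)::

            config_dict, unused_kwargs = BertConfig.get_config_dict('bert-base-uncased')
            config = BertConfig.from_dict(config_dict)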

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
        """
        Constructs a `Config` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
                from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
                method.
            kwargs (:obj:`Dict[str, any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :class:`PretrainedConfig`: An instance of a configuration object
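
        Example (illustrative; the dictionary contents are an assumption)::

            config = PretrainedConfig.from_dict({"output_attentions": True})
            assert config.output_attentions is True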
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: str) -> "PretrainedConfig":
        """
        Constructs a `Config` from the path to a json file of parameters.

        Args:
            json_file (:obj:`string`):
                Path to the JSON file containing the parameters.

        Returns:
            :class:`PretrainedConfig`: An instance of a configuration object
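
        Example (illustrative; assumes the JSON file exists on disk)::

            config = PretrainedConfig.from_json_file('./my_model_directory/config.json')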

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: str):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return "{} {}".format(self.__class__.__name__, self.to_json_string())

    def to_diff_dict(self):
        """
        Removes all attributes from config which correspond to the default
        config attributes for better readability and serializes to a Python
        dictionary.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of the attributes that differ from the defaults of :class:`PretrainedConfig`.
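
        Example (a sketch of the intended behaviour; only the non-default attribute should survive)::

            config = PretrainedConfig(output_attentions=True)
            assert config.to_diff_dict() == {"output_attentions": True}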
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if key not in default_config_dict or value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        return output

    def to_json_string(self, use_diff=True):
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`):
                If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.

        Returns:
            :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path, use_diff=True):
        """
        Save this instance to a json file.

        Args:
            json_file_path (:obj:`string`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`):
                If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
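
        Example (illustrative; the output path is an assumption)::

            config = PretrainedConfig(output_attentions=True)
            config.to_json_file('./config.json')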
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict):
        """
        Updates attributes of this class
        with attributes from `config_dict`.

        Args:
            config_dict (:obj:`Dict[str, any]`): Dictionary of attributes that shall be updated for this class.
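
        Example (illustrative)::

            config = PretrainedConfig()
            config.update({"output_attentions": True})
            assert config.output_attentions is True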
        """
        for key, value in config_dict.items():
            setattr(self, key, value)