# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import os
from typing import Any, Dict, Tuple

from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging


logger = logging.get_logger(__name__)


class PretrainedConfig(object):
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    Note: A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
    initialize a model does **not** load the model weights. It only affects the model's configuration.

    Class attributes (overridden by derived classes)

        - **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
          recreate the correct object in :class:`~transformers.AutoConfig`.
        - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
          the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
          like :class:`~transformers.EncoderDecoderConfig` or :class:`~transformers.RagConfig`.

    Args:
        name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
            Store the string that was passed to :func:`~transformers.PreTrainedModel.from_pretrained` or
            :func:`~transformers.TFPreTrainedModel.from_pretrained` as ``pretrained_model_name_or_path`` if the
            configuration was created with such a method.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return all hidden-states.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return all attentions.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return a :class:`~transformers.file_utils.ModelOutput` instead of a plain
            tuple.
        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the model is used as an encoder/decoder or not.
        is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the model is used as a decoder or not (in which case it's used as an encoder).
        add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether cross-attention layers should be added to the model. Note that this option is only relevant for
            models that can be used as decoder models within the :class:`~transformers.EncoderDecoderModel` class,
            which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
        tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
            and decoder model to have the exact same parameter names.
        pruned_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
            Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
            heads to prune in said layer.

            For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        xla_device (:obj:`bool`, `optional`):
            A flag to indicate whether a TPU is available or not.
        chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
            The chunk size of all feed forward layers in the residual attention blocks. A chunk size of :obj:`0` means
            that the feed forward layer is not chunked. A chunk size of :obj:`n` means that the feed forward layer
            processes :obj:`n` < sequence_length embeddings at a time. For more information on feed forward chunking,
            see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__.

    Parameters for sequence generation

        - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by default in the
          :obj:`generate` method of the model.
        - **min_length** (:obj:`int`, `optional`, defaults to 0) -- Minimum length that will be used by default in the
          :obj:`generate` method of the model.
        - **do_sample** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by default in the
          :obj:`generate` method of the model. Whether or not to use sampling; use greedy decoding otherwise.
        - **early_stopping** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by default
          in the :obj:`generate` method of the model. Whether to stop the beam search when at least ``num_beams``
          sentences are finished per batch or not.
        - **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by
          default in the :obj:`generate` method of the model. 1 means no beam search.
        - **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to modulate the next token
          probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
          positive.
        - **top_k** (:obj:`int`, `optional`, defaults to 50) -- Number of highest probability vocabulary tokens to keep
          for top-k-filtering that will be used by default in the :obj:`generate` method of the model.
        - **top_p** (:obj:`float`, `optional`, defaults to 1) -- Value that will be used by default in the
          :obj:`generate` method of the model for ``top_p``. If set to float < 1, only the most probable tokens with
          probabilities that add up to ``top_p`` or higher are kept for generation.
        - **repetition_penalty** (:obj:`float`, `optional`, defaults to 1) -- Parameter for repetition penalty that
          will be used by default in the :obj:`generate` method of the model. 1.0 means no penalty.
        - **length_penalty** (:obj:`float`, `optional`, defaults to 1) -- Exponential penalty to the length that will
          be used by default in the :obj:`generate` method of the model.
        - **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
          :obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
          can only occur once.
        - **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
          that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
          words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
          add_prefix_space=True)`.
        - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed returned
          sequences for each element in the batch that will be used by default in the :obj:`generate` method of the
          model.

    Parameters for fine-tuning tasks

        - **architectures** (:obj:`List[str]`, `optional`) -- Model architectures that can be used with the model
          pretrained weights.
        - **finetuning_task** (:obj:`str`, `optional`) -- Name of the task used to fine-tune the model. This can be
          used when converting from an original (TensorFlow or PyTorch) checkpoint.
        - **id2label** (:obj:`Dict[int, str]`, `optional`) -- A map from index (for instance prediction index, or
          target index) to label.
        - **label2id** (:obj:`Dict[str, int]`, `optional`) -- A map from label to index for the model.
        - **num_labels** (:obj:`int`, `optional`) -- Number of labels to use in the last layer added to the model,
          typically for a classification task.
        - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
          current task.

    Parameters linked to the tokenizer

        - **prefix** (:obj:`str`, `optional`) -- A specific prompt that should be added at the beginning of each text
          before calling the model.
        - **bos_token_id** (:obj:`int`, `optional`) -- The id of the `beginning-of-stream` token.
        - **pad_token_id** (:obj:`int`, `optional`) -- The id of the `padding` token.
        - **eos_token_id** (:obj:`int`, `optional`) -- The id of the `end-of-stream` token.
        - **decoder_start_token_id** (:obj:`int`, `optional`) -- If an encoder-decoder model starts decoding with a
          different token than `bos`, the id of that token.
        - **sep_token_id** (:obj:`int`, `optional`) -- The id of the `separation` token.

    PyTorch specific parameters

        - **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
          used with Torchscript.
        - **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
          output word embeddings should be tied. Note that this is only relevant if the model has an output word
          embedding layer.

    TensorFlow specific parameters

        - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
          BFloat16 scalars (only used by some TensorFlow models).
    """
    model_type: str = ""
    is_composition: bool = False

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", True
        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

        # `is_decoder` is used in encoder-decoder models to differentiate encoder from decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            # Keys are always strings in JSON so convert ids to int here.
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise err

    @property
    def name_or_path(self) -> str:
        return self._name_or_path

    @name_or_path.setter
    def name_or_path(self, value):
        self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)

    @property
    def use_return_dict(self) -> bool:
        """
        :obj:`bool`: Whether or not to return :class:`~transformers.file_utils.ModelOutput` instead of tuples.
        """
        # If torchscript is set, force `return_dict=False` to avoid jit errors
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        :obj:`int`: The number of labels for classification models.
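
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass)::

            config = BertConfig.from_pretrained('bert-base-uncased', num_labels=3)
            assert config.num_labels == 3 and len(config.id2label) == 3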
        """
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels: int):
        self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

    def save_pretrained(self, save_directory: str):
        """
        Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
        :func:`~transformers.PretrainedConfig.from_pretrained` class method.

        Args:
            save_directory (:obj:`str`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
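
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass and a hypothetical
        directory path)::

            config = BertConfig.from_pretrained('bert-base-uncased')
            config.save_pretrained('./my_model_directory/')
            # the saved configuration can then be re-loaded with
            # config = BertConfig.from_pretrained('./my_model_directory/')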
        """
        if os.path.isfile(save_directory):
            raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info("Configuration saved in {}".format(output_config_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig":
        r"""
        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
        configuration.

        Args:
            pretrained_model_name_or_path (:obj:`str`):
                This can be either:

                - the `shortcut name` of a pretrained model configuration to load from cache or download, e.g.,
                  ``bert-base-uncased``.
                - the `identifier name` of a pretrained model configuration that was uploaded to our S3 by any user,
                  e.g., ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a configuration file saved using the
                  :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`, e.g.,
                  ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the configuration files and override the cached versions if
                they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Attempts to resume the download if such a file
                exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            kwargs (:obj:`Dict[str, Any]`, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the ``return_unused_kwargs`` keyword parameter.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.

        Examples::

            # We can't instantiate the base class `PretrainedConfig` directly, so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            assert config.output_attentions == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attentions == True
            assert unused_kwargs == {'foo': False}

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
        :class:`~transformers.PretrainedConfig` using ``from_dict``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object, and
            the remaining keyword arguments that were not used to resolve it.
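
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass)::

            config_dict, kwargs = BertConfig.get_config_dict('bert-base-uncased')
            # `config_dict` holds the raw parameters from the checkpoint's config.json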

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(
                pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
            )

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the
                :func:`~transformers.PretrainedConfig.get_config_dict` method.
            kwargs (:obj:`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from those parameters.
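
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass)::

            config_dict, kwargs = BertConfig.get_config_dict('bert-base-uncased')
            config = BertConfig.from_dict(config_dict)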
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: str) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.

        Args:
            json_file (:obj:`str`):
                Path to the JSON file containing the parameters.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from that JSON file.
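
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass and a hypothetical
        file path)::

            config = BertConfig.from_json_file('./my_model_directory/config.json')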

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: str):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return "{} {}".format(self.__class__.__name__, self.to_json_string())

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, keeping
            only the ones that differ from the default configuration.
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        # get class specific config dict
        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to a JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON format.
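
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass)::

            config = BertConfig()
            print(config.to_json_string())                # with use_diff=True, base default values are omitted
            print(config.to_json_string(use_diff=False))  # the full configuration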
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: str, use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (:obj:`str`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to a JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this instance with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
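
        Example (illustrative, using :class:`~transformers.BertConfig` as a concrete subclass)::

            config = BertConfig.from_pretrained('bert-base-uncased')
            config.update({'output_attentions': True, 'num_beams': 5})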
        """
        for key, value in config_dict.items():
            setattr(self, key, value)