"scripts/playground/vscode:/vscode.git/clone" did not exist on "37963394aa28769abb1843d4373ae799d4e93f07"
configuration_utils.py 31.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import os
from typing import Any, Dict, Tuple, Union

from . import __version__
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging


logger = logging.get_logger(__name__)


class PretrainedConfig(object):
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    Note: A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
    initialize a model does **not** load the model weights. It only affects the model's configuration.

    Class attributes (overridden by derived classes)

        - **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
          recreate the correct object in :class:`~transformers.AutoConfig`.
        - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
          the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
          like: :class:`~transformers.EncoderDecoderConfig` or :class:`~transformers.RagConfig`.
        - **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at
          dictionary outputs of the model during inference.

    Args:
        name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
            Store the string that was passed to :func:`~transformers.PreTrainedModel.from_pretrained` or
            :func:`~transformers.TFPreTrainedModel.from_pretrained` as ``pretrained_model_name_or_path`` if the
            configuration was created with such a method.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return all hidden-states.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return all attentions.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return a :class:`~transformers.file_utils.ModelOutput` instead of a plain
            tuple.
        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the model is used as an encoder/decoder or not.
        is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the model is used as decoder or not (in which case it's used as an encoder).
        add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
            that can be used as decoder models within the :class:`~transformers.EncoderDecoderModel` class, which
            consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
        tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
            and decoder model to have the exact same parameter names.
        pruned_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
            Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
            heads to prune in said layer.

            For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        xla_device (:obj:`bool`, `optional`):
            A flag to indicate if a TPU is available or not.
        chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
            The chunk size of all feed forward layers in the residual attention blocks. A chunk size of :obj:`0` means
            that the feed forward layer is not chunked. A chunk size of :obj:`n` means that the feed forward layer
            processes :obj:`n` < sequence_length embeddings at a time. For more information on feed forward chunking,
            see `How does Feed Forward Chunking work? <../glossary.html#feed-forward-chunking>`__ .

    Parameters for sequence generation

        - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by default in the
          :obj:`generate` method of the model.
        - **min_length** (:obj:`int`, `optional`, defaults to 0) -- Minimum length that will be used by default in the
          :obj:`generate` method of the model.
        - **do_sample** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by default in the
          :obj:`generate` method of the model. Whether or not to use sampling; use greedy decoding otherwise.
        - **early_stopping** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by default
          in the :obj:`generate` method of the model. Whether to stop the beam search when at least ``num_beams``
          sentences are finished per batch or not.
        - **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by
          default in the :obj:`generate` method of the model. 1 means no beam search.
        - **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams`
          into in order to ensure diversity among different groups of beams that will be used by default in the
          :obj:`generate` method of the model. 1 means no group beam search.
        - **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group
          beam search that will be used by default in the :obj:`generate` method of the model. 0 means no diversity
          penalty. The higher the penalty, the more diverse are the outputs.
        - **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to modulate the next token
          probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
          positive.
        - **top_k** (:obj:`int`, `optional`, defaults to 50) -- Number of highest probability vocabulary tokens to keep
          for top-k-filtering that will be used by default in the :obj:`generate` method of the model.
        - **top_p** (:obj:`float`, `optional`, defaults to 1) -- Value that will be used by default in the
          :obj:`generate` method of the model for ``top_p``. If set to float < 1, only the most probable tokens with
          probabilities that add up to ``top_p`` or higher are kept for generation.
        - **repetition_penalty** (:obj:`float`, `optional`, defaults to 1) -- Parameter for repetition penalty that
          will be used by default in the :obj:`generate` method of the model. 1.0 means no penalty.
        - **length_penalty** (:obj:`float`, `optional`, defaults to 1) -- Exponential penalty to the length that will
          be used by default in the :obj:`generate` method of the model.
        - **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
          :obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
          can only occur once.
        - **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
          that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
          words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
          add_prefix_space=True)`.
        - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed returned
          sequences for each element in the batch that will be used by default in the :obj:`generate` method of the
          model.
        - **output_scores** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should return the
          logits when used for generation.
        - **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
          return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`.

    Parameters for fine-tuning tasks

        - **architectures** (:obj:`List[str]`, `optional`) -- Model architectures that can be used with the model
          pretrained weights.
        - **finetuning_task** (:obj:`str`, `optional`) -- Name of the task used to fine-tune the model. This can be
          used when converting from an original (TensorFlow or PyTorch) checkpoint.
        - **id2label** (:obj:`Dict[int, str]`, `optional`) -- A map from index (for instance prediction index, or
          target index) to label.
        - **label2id** (:obj:`Dict[str, int]`, `optional`) -- A map from label to index for the model.
        - **num_labels** (:obj:`int`, `optional`) -- Number of labels to use in the last layer added to the model,
          typically for a classification task.
        - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
          current task.

    Parameters linked to the tokenizer

        - **tokenizer_class** (:obj:`str`, `optional`) -- The name of the associated tokenizer class to use (if none is
          set, will use the tokenizer associated to the model by default).
        - **prefix** (:obj:`str`, `optional`) -- A specific prompt that should be added at the beginning of each text
          before calling the model.
        - **bos_token_id** (:obj:`int`, `optional`) -- The id of the `beginning-of-stream` token.
        - **pad_token_id** (:obj:`int`, `optional`) -- The id of the `padding` token.
        - **eos_token_id** (:obj:`int`, `optional`) -- The id of the `end-of-stream` token.
        - **decoder_start_token_id** (:obj:`int`, `optional`) -- If an encoder-decoder model starts decoding with a
          different token than `bos`, the id of that token.
        - **sep_token_id** (:obj:`int`, `optional`) -- The id of the `separation` token.

    PyTorch specific parameters

        - **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
          used with Torchscript.
        - **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
          output word embeddings should be tied. Note that this is only relevant if the model has an output word
          embedding layer.

    TensorFlow specific parameters

        - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
          BFloat16 scalars (only used by some TensorFlow models).
    """
    model_type: str = ""
    is_composition: bool = False

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", True)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", True
        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
            # Keys are always strings in JSON so convert ids to int here.
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Drop the transformers version info
        kwargs.pop("transformers_version", None)

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise err

    @property
    def name_or_path(self) -> str:
        return self._name_or_path

    @name_or_path.setter
    def name_or_path(self, value):
        self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)

    @property
    def use_return_dict(self) -> bool:
        """
        :obj:`bool`: Whether or not to return :class:`~transformers.file_utils.ModelOutput` instead of tuples.
        """
        # If torchscript is set, force `return_dict=False` to avoid jit errors
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        :obj:`int`: The number of labels for classification models.
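
        Example (illustrative; ``BertConfig`` stands in for any concrete subclass)::

            config = BertConfig(num_labels=3)
            assert config.num_labels == 3
            assert config.id2label == {0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"}
            assert config.label2id == {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}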
        """
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels: int):
        self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        """
        Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
        :func:`~transformers.PretrainedConfig.from_pretrained` class method.

        Args:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
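
        Example (illustrative sketch; ``BertConfig`` stands in for any concrete subclass and the paths are
        hypothetical)::

            config = BertConfig.from_pretrained('bert-base-uncased')
            config.save_pretrained('./my_model_directory/')
            # The directory now contains a `config.json` that `from_pretrained` can load again
            reloaded = BertConfig.from_pretrained('./my_model_directory/')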
        """
        if os.path.isfile(save_directory):
            raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info("Configuration saved in {}".format(output_config_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        r"""
        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
        configuration.

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                This can be either:

                - a string, the `model id` of a pretrained model configuration hosted inside a model repo on
                  huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                  namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a configuration file saved using the
                  :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`, e.g.,
                  ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the configuration files and override the cached versions
                if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Attempts to resume the download if such a file
                exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            use_auth_token (:obj:`str` or `bool`, `optional`):
                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            kwargs (:obj:`Dict[str, Any]`, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the ``return_unused_kwargs`` keyword parameter.

        .. note::

            Passing :obj:`use_auth_token=True` is required when you want to use a private model.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.

        Examples::

            # We can't instantiate the base class `PretrainedConfig` directly, so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from huggingface.co and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            assert config.output_attentions == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attentions == True
            assert unused_kwargs == {'foo': False}

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
        :class:`~transformers.PretrainedConfig` using ``from_dict``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.

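        Example (illustrative; any identifier accepted by :meth:`from_pretrained` can be passed here)::

            config_dict, unused_kwargs = BertConfig.get_config_dict('bert-base-uncased')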
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(
                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
            )

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the
                :func:`~transformers.PretrainedConfig.get_config_dict` method.
            kwargs (:obj:`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from those parameters.
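
        Example (illustrative; pairs with :meth:`~transformers.PretrainedConfig.get_config_dict`)::

            config_dict, kwargs = BertConfig.get_config_dict('bert-base-uncased')
            config = BertConfig.from_dict(config_dict)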
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.

        Args:
            json_file (:obj:`str` or :obj:`os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from that JSON file.
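
        Example (illustrative; the path is hypothetical)::

            config = BertConfig.from_json_file('./my_model_directory/config.json')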

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return "{} {}".format(self.__class__.__name__, self.to_json_string())

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
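
        Example (illustrative; only values that differ from the defaults survive in the diff)::

            config = BertConfig(output_attentions=True)
            diff = config.to_diff_dict()
            assert diff["output_attentions"] is True
            assert "max_length" not in diff  # unchanged defaults are dropped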
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        # get class specific config dict
        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key == "transformers_version"
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type

        # Transformers version when serializing the model
        output["transformers_version"] = __version__

        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON format.
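
        Example (illustrative)::

            config = BertConfig(output_attentions=True)
            short = config.to_json_string(use_diff=True)   # only the overridden values, plus the version info
            full = config.to_json_string(use_diff=False)   # every attribute of this instance
            assert len(short) < len(full)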
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (:obj:`str` or :obj:`os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this instance with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this instance.
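
        Example (illustrative; ``my_custom_flag`` is just a made-up attribute)::

            config = BertConfig()
            config.update({"output_attentions": True, "my_custom_flag": 1})
            assert config.output_attentions is True
            assert config.my_custom_flag == 1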
        """
        for key, value in config_dict.items():
            setattr(self, key, value)