# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import os
import re
import warnings
from typing import Any, Dict, Optional, Tuple, Union

from packaging import version

from . import __version__
from .file_utils import (
    CONFIG_NAME,
    PushToHubMixin,
    cached_path,
    copy_func,
    get_list_of_files,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
    is_torch_available,
)
from .utils import logging


logger = logging.get_logger(__name__)
FULL_CONFIGURATION_FILE = "config.json"
_re_configuration_file = re.compile(r"config\.(.*)\.json")


class PretrainedConfig(PushToHubMixin):
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    <Tip>

    A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
    initialize a model does **not** load the model weights. It only affects the model's configuration.

    </Tip>

    Class attributes (overridden by derived classes):

    - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
      the correct object in [`~transformers.AutoConfig`].
    - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
      config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
      [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
    - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
      outputs of the model during inference.
    - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
      naming of attributes.

    Common attributes (present in all subclasses):

    - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
      embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
    - **hidden_size** (`int`) -- The hidden size of the model.
    - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
      model.
    - **num_hidden_layers** (`int`) -- The number of blocks in the model.

    Args:
        name_or_path (`str`, *optional*, defaults to `""`):
            Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
            [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
            with such a method.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return all hidden-states.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return all attentions.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return a [`~transformers.file_utils.ModelOutput`] instead of a plain tuple.
        is_encoder_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as an encoder/decoder or not.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not (if not, it is used as an encoder).
        cross_attention_hidden_size (`int`, *optional*):
            The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
            setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
        add_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
            that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
            in `AUTO_MODELS_FOR_CAUSAL_LM`.
        tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
            Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
            and decoder model to have the exact same parameter names.
        pruned_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
            Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
            heads to prune in said layer.

            For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
            The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
            the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
            sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
            Forward Chunking work?](../glossary.html#feed-forward-chunking).

        > Parameters for sequence generation

        max_length (`int`, *optional*, defaults to 20):
            Maximum length that will be used by default in the `generate` method of the model.
        min_length (`int`, *optional*, defaults to 0):
            Minimum length that will be used by default in the `generate` method of the model.
        do_sample (`bool`, *optional*, defaults to `False`):
            Flag that will be used by default in the `generate` method of the model. Whether or not to use sampling;
            use greedy decoding otherwise.
        early_stopping (`bool`, *optional*, defaults to `False`):
            Flag that will be used by default in the `generate` method of the model. Whether to stop the beam search
            when at least `num_beams` sentences are finished per batch or not.
        num_beams (`int`, *optional*, defaults to 1):
            Number of beams for beam search that will be used by default in the `generate` method of the model. 1 means
            no beam search.
        num_beam_groups (`int`, *optional*, defaults to 1):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams
            that will be used by default in the `generate` method of the model. 1 means no group beam search.
        diversity_penalty (`float`, *optional*, defaults to 0.0):
            Value to control diversity for group beam search that will be used by default in the `generate` method of
            the model. 0 means no diversity penalty. The higher the penalty, the more diverse are the outputs.
        temperature (`float`, *optional*, defaults to 1):
            The value used to modulate the next token probabilities that will be used by default in the `generate`
            method of the model. Must be strictly positive.
        top_k (`int`, *optional*, defaults to 50):
            Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in
            the `generate` method of the model.
        top_p (`float`, *optional*, defaults to 1):
            Value that will be used by default in the `generate` method of the model for `top_p`. If set to float < 1,
            only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
        repetition_penalty (`float`, *optional*, defaults to 1):
            Parameter for repetition penalty that will be used by default in the `generate` method of the model. 1.0
            means no penalty.
        length_penalty (`float`, *optional*, defaults to 1):
            Exponential penalty to the length that will be used by default in the `generate` method of the model.
        no_repeat_ngram_size (`int`, *optional*, defaults to 0):
            Value that will be used by default in the `generate` method of the model for `no_repeat_ngram_size`. If set
            to int > 0, all ngrams of that size can only occur once.
        encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
            Value that will be used by default in the `generate` method of the model for `encoder_no_repeat_ngram_size`.
            If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
            `decoder_input_ids`.
        bad_words_ids (`List[int]`, *optional*):
            List of token ids that are not allowed to be generated that will be used by default in the `generate`
            method of the model. In order to get the tokens of the words that should not appear in the generated text,
            use `tokenizer.encode(bad_word, add_prefix_space=True)`.
        num_return_sequences (`int`, *optional*, defaults to 1):
            Number of independently computed returned sequences for each element in the batch that will be used by
            default in the `generate` method of the model.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether the model should return the logits when used for generation.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether the model should return a [`~transformers.file_utils.ModelOutput`] instead of a `torch.LongTensor`.
        forced_bos_token_id (`int`, *optional*):
            The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
            multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
            language token.
        forced_eos_token_id (`int`, *optional*):
            The id of the token to force as the last generated token when `max_length` is reached.
        remove_invalid_values (`bool`, *optional*):
            Whether to remove possible _nan_ and _inf_ outputs of the model to prevent the generation method from
            crashing. Note that using `remove_invalid_values` can slow down generation.

        > Parameters for fine-tuning tasks

        architectures (`List[str]`, *optional*):
            Model architectures that can be used with the model pretrained weights.
        finetuning_task (`str`, *optional*):
            Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
            or PyTorch) checkpoint.
        id2label (`Dict[int, str]`, *optional*):
            A map from index (for instance prediction index, or target index) to label.
        label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
        num_labels (`int`, *optional*):
            Number of labels to use in the last layer added to the model, typically for a classification task.
        task_specific_params (`Dict[str, Any]`, *optional*):
            Additional keyword arguments to store for the current task.
        problem_type (`str`, *optional*):
            Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
            `"single_label_classification"` or `"multi_label_classification"`.

        > Parameters linked to the tokenizer

        tokenizer_class (`str`, *optional*):
            The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
            model by default).
        prefix (`str`, *optional*):
            A specific prompt that should be added at the beginning of each text before calling the model.
        bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
        pad_token_id (`int`, *optional*): The id of the _padding_ token.
        eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
        decoder_start_token_id (`int`, *optional*):
            If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
        sep_token_id (`int`, *optional*): The id of the _separation_ token.

        > PyTorch specific parameters

        torchscript (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be used with Torchscript.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has an output word embedding layer.
        torch_dtype (`str`, *optional*):
            The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
            (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
            model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
            `float16` weights. Since the config object is stored in plain text, this attribute contains just the
            floating type string without the `torch.` prefix. For example, for `torch.float16`, `torch_dtype` is the
            `"float16"` string.

            This attribute is currently not being used during model loading time, but this may change in future
            versions. We can already start preparing for the future by saving the dtype with `save_pretrained`.

        > TensorFlow specific parameters

        use_bfloat16 (`bool`, *optional*, defaults to `False`):
            Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
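
    Example (an illustrative sketch; `BertConfig` is just one concrete subclass):

    ```python
    config = BertConfig(output_attentions=True, my_custom_attribute=1)
    config.output_attentions  # True
    config.my_custom_attribute  # extra keyword arguments are simply stored as attributes
    ```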
    """
    model_type: str = ""
    is_composition: bool = False
    attribute_map: Dict[str, str] = {}

    def __setattr__(self, key, value):
        if key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        super().__setattr__(key, value)

    def __getattribute__(self, key):
        if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        return super().__getattribute__(key)
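    # Illustrative note: with a hypothetical `attribute_map` such as `{"hidden_dim": "hidden_size"}` in a subclass,
    # reading or writing `config.hidden_dim` is transparently redirected to `config.hidden_size` by the two
    # methods above.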

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", True)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.torch_dtype = kwargs.pop("torch_dtype", None)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", True
        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
            # Keys are always strings in JSON so convert ids to int here.
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
            # we will start using self.torch_dtype in v5, but to be consistent with
            # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
            if is_torch_available():
                import torch

                self.torch_dtype = getattr(torch, self.torch_dtype)
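                # For example, a config saved with `"torch_dtype": "float16"` ends up as `torch.float16` here.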

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # regression / multi-label classification
        self.problem_type = kwargs.pop("problem_type", None)
        allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
        if self.problem_type is not None and self.problem_type not in allowed_problem_types:
            raise ValueError(
                f"The config parameter `problem_type` was not understood: received {self.problem_type} "
                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
            )

        # TPU arguments
        if kwargs.pop("xla_device", None) is not None:
            logger.warning(
                "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
                "safely remove it from your `config.json` file."
            )

        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Drop the transformers version info
        self.transformers_version = kwargs.pop("transformers_version", None)

        # Deal with gradient checkpointing
        if kwargs.get("gradient_checkpointing", False):
            warnings.warn(
                "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
                "of Transformers. Use `model.gradient_checkpointing_enable()` instead, or if you are using the "
                "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
            )

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    @property
    def name_or_path(self) -> str:
        return self._name_or_path

    @name_or_path.setter
    def name_or_path(self, value):
        self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)

    @property
    def use_return_dict(self) -> bool:
        """
        `bool`: Whether or not to return [`~file_utils.ModelOutput`] instead of tuples.
        """
        # If torchscript is set, force `return_dict=False` to avoid jit errors
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        `int`: The number of labels for classification models.
        """
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels: int):
        if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
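        # Illustrative note: when `id2label` is unset or of a different length, setting `config.num_labels = 3`
        # yields id2label == {0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"} and the matching label2id mapping.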

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
395
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                <Tip warning={true}>

                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
                folder. Pass along `temp_dir=True` to use a temporary directory instead.

                </Tip>

            kwargs:
                Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
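
        Example (an illustrative sketch; `BertConfig` is just one concrete subclass):

        ```python
        from transformers import BertConfig

        config = BertConfig.from_pretrained("bert-base-uncased")
        config.save_pretrained("./my-bert-config")  # writes ./my-bert-config/config.json
        reloaded = BertConfig.from_pretrained("./my-bert-config")
        ```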
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Configuration pushed to the hub in this commit: {url}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        r"""
        Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
                  huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                  namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                - a path to a *directory* containing a configuration file saved using the
                  [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force (re-)downloading the configuration files and overriding the cached versions if
                they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete an incompletely received file. Attempts to resume the download if such a file
                exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            kwargs (`Dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `use_auth_token=True` is required when you want to use a private model.

        </Tip>

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.

        Examples:

        ```python
        # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
        # derived class: BertConfig
        config = BertConfig.from_pretrained(
            "bert-base-uncased"
        )  # Download configuration from huggingface.co and cache.
        config = BertConfig.from_pretrained(
            "./test/saved_model/"
        )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
        config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
        config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
        assert config.output_attentions == True
        config, unused_kwargs = BertConfig.from_pretrained(
            "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        )
        assert config.output_attentions == True
        assert unused_kwargs == {"foo": False}
        ```"""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warn(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
        [`PretrainedConfig`] using `from_dict`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
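
        Example (illustrative; `BertConfig` is one derived class):

        ```python
        config_dict, unused_kwargs = BertConfig.get_config_dict("bert-base-uncased")
        config_dict["model_type"]  # "bert"
        ```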

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            configuration_file = get_configuration_file(
                pretrained_model_name_or_path,
                revision=revision,
                use_auth_token=use_auth_token,
                local_files_only=local_files_only,
            )

            if os.path.isdir(pretrained_model_name_or_path):
                config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
            else:
                config_file = hf_bucket_url(
                    pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
                )

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
                f"  (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )

            if revision is not None:
                msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"

            raise EnvironmentError(msg)

        except (json.JSONDecodeError, UnicodeDecodeError):
            msg = (
                f"Couldn't reach server at '{config_file}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                f"Please check network or file content here: {resolved_config_file}."
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info(f"loading configuration file {config_file}")
        else:
            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from those parameters.
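
        Example (a minimal sketch; `BertConfig` is one derived class):

        ```python
        config_dict, kwargs = BertConfig.get_config_dict("bert-base-uncased")
        config = BertConfig.from_dict(config_dict)
        ```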
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                if key != "torch_dtype":
                    to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info(f"Model config {config}")
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
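
        Example (illustrative; the path is hypothetical):

        ```python
        config = BertConfig.from_json_file("./my_model_directory/config.json")
        ```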

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
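
        Example (illustrative):

        ```python
        config = PretrainedConfig(output_attentions=True)
        diff = config.to_diff_dict()
        assert "output_attentions" in diff  # differs from the default `False`
        assert "output_hidden_states" not in diff  # still at its default value
        ```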
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        # get class specific config dict
        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key == "transformers_version"
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        self.dict_torch_dtype_to_str(serializable_config_dict)

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type

        # Transformers version when serializing the model
        output["transformers_version"] = __version__

        self.dict_torch_dtype_to_str(output)

        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
                is serialized to JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
                is serialized to JSON file.
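
        Example (illustrative; the file name is arbitrary):

        ```python
        config = PretrainedConfig(output_hidden_states=True)
        config.to_json_file("config.json")
        ```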
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from `config_dict`.

        Args:
            config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
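
        Example (illustrative):

        ```python
        config = PretrainedConfig()
        config.update({"output_attentions": True, "num_labels": 3})
        ```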
        """
        for key, value in config_dict.items():
            setattr(self, key, value)

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from `update_str`.

        The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
        "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

        The keys to change have to already exist in the config object.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.
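
        Example (illustrative; assumes a GPT-2 style config that defines these attributes):

        ```python
        config = GPT2Config()
        config.update_from_string("n_embd=10,resid_pdrop=0.2,scale_attn_weights=false")
        ```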

        """

        d = dict(x.split("=") for x in update_str.split(","))
        for k, v in d.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            old_v = getattr(self, k)
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)

    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary has a *torch_dtype* key and if it's not None, converts torch.dtype to a
        string of just the type. For example, `torch.float32` gets converted into the *"float32"* string, which can
        then be stored in the JSON format.
        """
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]


def get_configuration_file(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
) -> str:
    """
    Get the configuration file to use for this version of transformers.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a *directory*.
        revision(`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether or not to only rely on local files and not to attempt to download any files.

    Returns:
        `str`: The configuration file to use.
    """
    # Inspect all files from the repo/folder.
    all_files = get_list_of_files(
        path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
    )
    configuration_files_map = {}
    for file_name in all_files:
        search = _re_configuration_file.search(file_name)
        if search is not None:
            v = search.groups()[0]
            configuration_files_map[v] = file_name
    available_versions = sorted(configuration_files_map.keys())

    # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
    configuration_file = FULL_CONFIGURATION_FILE
    transformers_version = version.parse(__version__)
    for v in available_versions:
        if version.parse(v) <= transformers_version:
            configuration_file = configuration_files_map[v]
        else:
            # No point going further since the versions are sorted.
            break

    return configuration_file


PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
    object="config", object_class="AutoConfig", object_files="configuration file"
)