# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""


import copy
import json
import logging
import os
from typing import Any, Dict, Tuple

from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url


logger = logging.getLogger(__name__)


class PretrainedConfig(object):
    r""" Base class for all configuration classes.
        Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving
        configurations.

        Note:
            A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
            initialize a model does **not** load the model weights.
            It only affects the model's configuration.

        Class attributes (overridden by derived classes)
            - **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
              recreate the correct object in :class:`~transformers.AutoConfig`.

        Args:
            output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the model should return all hidden-states.
            output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the model should return all attentions.
            use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not the model should return the last key/values attentions (not used by all models).
            return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the model should return a :class:`~transformers.file_utils.ModelOutput` instead of a
                plain tuple.
            is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the model is used as an encoder/decoder or not.
            is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the model is used as a decoder or not (if not, it is used as an encoder).
            pruned_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
                Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
                of heads to prune in said layer.

                For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer
                2.
            xla_device (:obj:`bool`, `optional`):
                A flag to indicate if a TPU is available or not.

        Parameters for sequence generation
            - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
              default in the :obj:`generate` method of the model.
            - **min_length** (:obj:`int`, `optional`, defaults to 0) -- Minimum length that will be used by
              default in the :obj:`generate` method of the model.
            - **do_sample** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by default in
              the :obj:`generate` method of the model. Whether or not to use sampling; use greedy decoding otherwise.
            - **early_stopping** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by
              default in the :obj:`generate` method of the model. Whether to stop the beam search when at least
              ``num_beams`` sentences are finished per batch or not.
            - **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be
              used by default in the :obj:`generate` method of the model. 1 means no beam search.
            - **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to modulate the next token
              probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
              positive.
            - **top_k** (:obj:`int`, `optional`, defaults to 50) -- Number of highest probability vocabulary tokens to
              keep for top-k-filtering that will be used by default in the :obj:`generate` method of the model.
            - **top_p** (:obj:`float`, `optional`, defaults to 1) --  Value that will be used by default in the
              :obj:`generate` method of the model for ``top_p``. If set to float < 1, only the most probable tokens
              with probabilities that add up to ``top_p`` or higher are kept for generation.
            - **repetition_penalty** (:obj:`float`, `optional`, defaults to 1) -- Parameter for repetition penalty
              that will be used by default in the :obj:`generate` method of the model. 1.0 means no penalty.
            - **length_penalty** (:obj:`float`, `optional`, defaults to 1) -- Exponential penalty to the length that
              will be used by default in the :obj:`generate` method of the model.
            - **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default
              in the :obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of
              that size can only occur once.
            - **bad_words_ids** (:obj:`List[List[int]]`, `optional`) -- List of token ids that are not allowed to be
              generated that will be used by default in the :obj:`generate` method of the model. In order to get the
              tokens of the words that should not appear in the generated text, use
              :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
            - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed
              returned sequences for each element in the batch that will be used by default in the :obj:`generate`
              method of the model.

        Parameters for fine-tuning tasks
            - **architectures** (:obj:`List[str]`, `optional`) -- Model architectures that can be used with the
              model pretrained weights.
            - **finetuning_task** (:obj:`str`, `optional`) -- Name of the task used to fine-tune the model. This can be
              used when converting from an original (TensorFlow or PyTorch) checkpoint.
            - **id2label** (:obj:`Dict[int, str]`, `optional`) -- A map from index (for instance prediction index, or target
              index) to label.
            - **label2id** (:obj:`Dict[str, int]`, `optional`) -- A map from label to index for the model.
            - **num_labels** (:obj:`int`, `optional`) -- Number of labels to use in the last layer added to the model,
              typically for a classification task.
            - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for
              the current task.

        Parameters linked to the tokenizer
            - **prefix** (:obj:`str`, `optional`) -- A specific prompt that should be added at the beginning of each
              text before calling the model.
            - **bos_token_id** (:obj:`int`, `optional`) -- The id of the `beginning-of-stream` token.
            - **pad_token_id** (:obj:`int`, `optional`) -- The id of the `padding` token.
            - **eos_token_id** (:obj:`int`, `optional`) -- The id of the `end-of-stream` token.
            - **decoder_start_token_id** (:obj:`int`, `optional`) -- If an encoder-decoder model starts decoding with
              a different token than `bos`, the id of that token.

        PyTorch specific parameters
            - **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
              used with Torchscript.

        TensorFlow specific parameters
            - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should
              use BFloat16 scalars (only used by some TensorFlow models).
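
        Example (illustrative sketch; :class:`~transformers.BertConfig` is assumed as a concrete subclass, since the
        base class itself is not tied to any particular architecture)::

            from transformers import BertConfig

            # Override a couple of the common attributes handled by ``PretrainedConfig``.
            config = BertConfig(output_attentions=True, num_labels=3)

            assert config.output_attentions is True
            assert config.num_labels == 3  # derived from the generated ``id2label`` map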
    """
    model_type: str = ""

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            # Keys are always strings in JSON so convert ids to int here.
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise err

    @property
    def use_return_dict(self) -> bool:
        """
        :obj:`bool`: Whether or not to return :class:`~transformers.file_utils.ModelOutput` instead of tuples.
        """
        # If torchscript is set, force `return_dict=False` to avoid jit errors
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        :obj:`int`: The number of labels for classification models.
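
        Example (illustrative sketch; the base class is instantiated directly here just to show the mapping)::

            config = PretrainedConfig(num_labels=3)

            assert config.num_labels == 3
            assert config.id2label == {0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"}
            assert config.label2id == {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}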
        """
210
        return len(self.id2label)
211
212

    @num_labels.setter
213
    def num_labels(self, num_labels: int):
214
        self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
215
216
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

217
    def save_pretrained(self, save_directory: str):
        """
        Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
        :func:`~transformers.PretrainedConfig.from_pretrained` class method.

        Args:
            save_directory (:obj:`str`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
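
        Example (illustrative sketch; the directory path is arbitrary and :class:`~transformers.BertConfig` is only
        an assumed subclass)::

            from transformers import BertConfig

            config = BertConfig.from_pretrained('bert-base-uncased')
            config.save_pretrained('./my_model_directory/')  # writes the configuration JSON into the directory

            # The saved configuration can then be re-loaded with ``from_pretrained``.
            config = BertConfig.from_pretrained('./my_model_directory/')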
        """
        if os.path.isfile(save_directory):
            raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info("Configuration saved in {}".format(output_config_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig":
        r"""
        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
        configuration.

        Args:
            pretrained_model_name_or_path (:obj:`str`):
                This can be either:

                - the `shortcut name` of a pretrained model configuration to load from cache or download, e.g.,
                  ``bert-base-uncased``.
                - the `identifier name` of a pretrained model configuration that was uploaded to our S3 by any user,
                  e.g., ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a configuration file saved using the
                  :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`, e.g.,
                  ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the configuration files and override the cached versions
                if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete an incompletely received file. Attempts to resume the download if such a file
                exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
                request.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs`
                is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e.,
                the part of ``kwargs`` which has not been used to update ``config`` and is otherwise ignored.
            kwargs (:obj:`Dict[str, Any]`, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
                controlled by the ``return_unused_kwargs`` keyword parameter.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.

        Examples::

            # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            assert config.output_attentions == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attentions == True
            assert unused_kwargs == {'foo': False}

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used
        for instantiating a :class:`~transformers.PretrainedConfig` using ``from_dict``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
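
        Example (illustrative sketch; resolving the checkpoint requires network access or a local cache)::

            config_dict, kwargs = PretrainedConfig.get_config_dict('bert-base-uncased', output_attentions=True)

            # ``config_dict`` holds the raw JSON parameters of the checkpoint; kwargs that are
            # not consumed here (e.g. ``output_attentions``) are handed back unchanged.
            assert kwargs == {'output_attentions': True}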

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        return config_dict, kwargs

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the
                :func:`~transformers.PretrainedConfig.get_config_dict` method.
            kwargs (:obj:`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from those parameters.
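
        Example (illustrative sketch using the base class directly)::

            config_dict = {'output_attentions': True, 'use_cache': False}
            config, unused_kwargs = PretrainedConfig.from_dict(
                config_dict, num_beams=5, foo=False, return_unused_kwargs=True
            )

            assert config.output_attentions is True
            assert config.num_beams == 5            # known attribute, consumed from the kwargs
            assert unused_kwargs == {'foo': False}  # unknown key is handed back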
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_json_file(cls, json_file: str) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.

        Args:
            json_file (:obj:`str`):
                Path to the JSON file containing the parameters.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from that JSON file.
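
        Example (illustrative sketch; the path is arbitrary)::

            config = PretrainedConfig.from_json_file('./my_model_directory/my_configuration.json')
            assert isinstance(config, PretrainedConfig)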

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: str):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return "{} {}".format(self.__class__.__name__, self.to_json_string())

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from the config which correspond to the default config attributes, for better
        readability, and serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
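
        Example (illustrative sketch)::

            config = PretrainedConfig(num_beams=5)
            diff = config.to_diff_dict()

            # Only attributes that differ from a default ``PretrainedConfig()`` are kept,
            # so the overridden ``num_beams`` survives while e.g. the default ``top_k`` does not.
            assert diff['num_beams'] == 5
            assert 'top_k' not in diff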
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if key not in default_config_dict or value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: str, use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (:obj:`str`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
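
        Example (illustrative sketch; ``my_custom_flag`` is just an arbitrary new key)::

            config = PretrainedConfig()
            config.update({'output_attentions': True, 'my_custom_flag': 1})

            assert config.output_attentions is True
            assert config.my_custom_flag == 1  # unknown keys simply become new attributes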
        """
        for key, value in config_dict.items():
            setattr(self, key, value)