"tests/vscode:/vscode.git/clone" did not exist on "693cb828ff5e0e530db4f054092b6687439ede15"
# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Set, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from .configuration_utils import PretrainedConfig
from .file_utils import (
    FLAX_WEIGHTS_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
    cached_path,
    copy_func,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
    replace_return_docstrings,
)
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging


logger = logging.get_logger(__name__)


def quick_gelu(x):
    return x * jax.nn.sigmoid(1.702 * x)


ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
    "quick_gelu": quick_gelu,
}
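# A minimal usage sketch for ACT2FN (comment only; `hidden_act` is the usual config
# attribute for BERT-like models and is only illustrative here). Model modules resolve
# their activation function from the configuration string, e.g.:
#
#     act_fn = ACT2FN[config.hidden_act]
#     hidden_states = act_fn(hidden_states)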


class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """
    config_class = None
    base_model_prefix = ""
    main_input_name = "input_ids"

    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # These are private; they are exposed as typed properties on derived classes.
        self._config = config
        self._module = module

        # These are public as their type is generic to every derived class.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)

        # save required_params as set
        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @property
    def framework(self) -> str:
        """
        :str: Identifies that this is a Flax model.
        """
        return "flax"

    @property
    def config(self) -> PretrainedConfig:
        return self._config

    @property
    def module(self) -> nn.Module:
        return self._module

    @property
    def params(self) -> Union[Dict, FrozenDict]:
        return self._params

    @property
    def required_params(self) -> Set:
        return self._required_params

    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` includes the following "
                f"parameters {self.required_params - param_keys}"
            )
        self._params = params

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast the floating-point values of a given parameter `PyTree` to the given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            return jax.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_flatten(mask)
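        # only the parameters whose corresponding mask leaf is True are cast below;
        # flat_mask is assumed to flatten in the same leaf order as flat_params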

        for masked, key in zip(flat_mask, flat_params.keys()):
            if masked:
                param = flat_params[key]
                flat_params[key] = conditional_cast(param)

        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with the same structure as the `params` tree. The leaves should be booleans: `True` for
                params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel
        >>> # load model
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # By default, the model parameters will be in fp32 precision. To cast these to bfloat16 precision:
        >>> model.params = model.to_bf16(model.params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_bf16(model.params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.bfloat16, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with the same structure as the `params` tree. The leaves should be booleans: `True` for
                params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel
        >>> # Download model and configuration from huggingface.co
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # By default, the model params will be in fp32. To illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> model.params = model.to_fp16(model.params)
        >>> # now cast back to fp32
        >>> model.params = model.to_fp32(model.params)
        ```"""
        return self._cast_floating_to(params, jnp.float32, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
        `params` in place.

        This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with the same structure as the `params` tree. The leaves should be booleans: `True` for
                params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel
        >>> # load model
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # By default, the model params will be in fp32. To cast these to float16:
        >>> model.params = model.to_fp16(model.params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_fp16(model.params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs
    ):

        r"""
        Instantiate a pretrained Flax model from a pretrained model configuration.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *PyTorch checkpoint file* (e.g., `./pt_model/pytorch_model.bin`). In this
                      case, `from_pt` should be set to `True`.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified, all the computation will be performed with the given `dtype`.

                **Note that this only specifies the dtype of the computation and does not influence the dtype of model
                parameters.**

                If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
                [`~FlaxPreTrainedModel.to_bf16`].
            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
                Can be either:

                    - an instance of a class derived from [`PretrainedConfig`],
                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
                checkpoint with 3 labels).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initialize the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import BertConfig, FlaxBertModel
        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file('./pt_model/config.json')
        >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
        ```"""
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error: no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                        f"{pretrained_model_name_or_path}, or `from_pt` is set to `False`."
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
                    f"  (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        # initialize the model with random weights
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                    try:
                        with open(resolved_archive_file) as f:
                            if f.read().startswith("version"):
                                raise OSError(
                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                                    "you cloned."
                                )
                            else:
                                raise ValueError from e
                    except (UnicodeDecodeError, ValueError):
                        raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to work around a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            state = jax.tree_util.tree_map(jnp.array, state)

        # if the model is a base model, only use the base_model_prefix key
        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if the model is a head model and we are loading weights from a base model,
        # we initialize a new params dict with the base_model_prefix
        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # mismatched_keys contains tuples (key, shape1, shape2) of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model."
                    )

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys so they are not saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # set correct parameters
        model.params = unflatten_dict(state)

        return model

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~FlaxPreTrainedModel.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                <Tip warning={true}>

                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
                folder. Pass along `temp_dir=True` to use a temporary directory instead.

                </Tip>

            kwargs:
                Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
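
        Examples (a minimal sketch; `FlaxBertModel` and the checkpoint name are only illustrative):

        ```python
        >>> from transformers import FlaxBertModel
        >>> # load a pretrained model
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # save its weights and configuration locally
        >>> model.save_pretrained('./my_model_directory/')
        >>> # the saved model can then be reloaded
        >>> model = FlaxBertModel.from_pretrained('./my_model_directory/')
        ```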
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            params = params if params is not None else self.params
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")


# To update the docstring, we need to copy the method, otherwise we change the original docstring.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
    object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
)


def overwrite_call_docstring(model_class, docstring):
    # copy __call__ function to be sure docstring is changed only for this function
    model_class.__call__ = copy_func(model_class.__call__)
    # delete existing docstring
    model_class.__call__.__doc__ = None
    # set correct docstring
    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
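

# A minimal usage sketch for `overwrite_call_docstring` (comment only; the class and
# docstring constant below are illustrative of how model files typically use it):
#
#     overwrite_call_docstring(FlaxBertForSequenceClassification, BERT_INPUTS_DOCSTRING)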


def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = add_code_sample_docstrings(
        processor_class=tokenizer_class,
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
    )(model_class.__call__)


def append_replace_return_docstrings(model_class, output_type, config_class):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = replace_return_docstrings(
        output_type=output_type,
        config_class=config_class,
    )(model_class.__call__)
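

# Minimal usage sketches (comment only; the model, output, and docstring-constant names
# are illustrative of how model files typically wire these helpers up):
#
#     append_call_sample_docstring(
#         FlaxBertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
#     )
#     append_replace_return_docstrings(
#         FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
#     )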