# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Set, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from requests import HTTPError

from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .file_utils import (
    FLAX_WEIGHTS_NAME,
    WEIGHTS_NAME,
    EntryNotFoundError,
    PushToHubMixin,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
    cached_path,
    copy_func,
    has_file,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
    replace_return_docstrings,
)
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging


logger = logging.get_logger(__name__)


def quick_gelu(x):
    return x * jax.nn.sigmoid(1.702 * x)


ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
    "quick_gelu": quick_gelu,
}
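
# Model code typically looks an activation up by the name stored on its config. Illustrative sketch,
# assuming a hypothetical `config` object with a `hidden_act` attribute:
#
#     act_fn = ACT2FN[config.hidden_act]  # e.g. "gelu_new" -> partial(nn.gelu, approximate=True)
#     hidden_states = act_fn(hidden_states)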


class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """
    config_class = None
    base_model_prefix = ""
    main_input_name = "input_ids"
    _auto_class = None
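
    # The class attributes above are typically overridden in derived classes, e.g. (hypothetical sketch):
    #
    #     class FlaxMyModel(FlaxPreTrainedModel):
    #         config_class = MyConfig  # a PretrainedConfig subclass (hypothetical)
    #         base_model_prefix = "my_model"
    #         main_input_name = "input_ids"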

    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed properties on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to all derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)

        # save required_params as set
        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @property
    def framework(self) -> str:
        """
        :str: Identifies that this is a Flax model.
        """
        return "flax"

    @property
    def config(self) -> PretrainedConfig:
        return self._config

    @property
    def module(self) -> nn.Module:
        return self._module

    @property
    def params(self) -> Union[Dict, FrozenDict]:
        return self._params

    @property
    def required_params(self) -> Set:
        return self._required_params

    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {self.required_params - param_keys}"
            )
        self._params = params
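
    # Illustrative (hypothetical) inspection of the parameter PyTree on a loaded `model`:
    #
    #     flat = flatten_dict(model.params)           # keys are tuples such as ("encoder", "layer", ...)
    #     model.required_params - set(flat.keys())    # empty set when every expected weight is present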

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            return jax.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_flatten(mask)

        for masked, key in zip(flat_mask, flat_params.keys()):
            if masked:
                param = flat_params[key]
                flat_params[key] = conditional_cast(param)

        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and should be `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel

        >>> # load model
        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> model.params = model.to_bf16(model.params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_bf16(model.params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.bfloat16, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and should be `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel

        >>> # Download model and configuration from huggingface.co
        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> model.params = model.to_fp16(model.params)
        >>> # now cast back to fp32
        >>> model.params = model.to_fp32(model.params)
        ```"""
        return self._cast_floating_to(params, jnp.float32, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
        `params` in place.

        This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and should be `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel

        >>> # load model
        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # By default, the model params will be in fp32, to cast these to float16
        >>> model.params = model.to_fp16(model.params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_fp16(model.params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs
    ):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
                      `from_pt` should be set to `True`.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified all the computation will be performed with the given `dtype`.

                **Note that this only specifies the dtype of the computation and does not influence the dtype of model
                parameters.**

                If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
                [`~FlaxPreTrainedModel.to_bf16`].
            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
                Can be either:

                    - an instance of a class derived from [`PretrainedConfig`],
                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
                checkpoint with 3 labels).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
376
            revision (`str`, *optional*, defaults to `"main"`):
377
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
378
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
379
                identifier allowed by git.
380
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initialize the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import BertConfig, FlaxBertModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
405
        >>> config = BertConfig.from_json_file("./pt_model/config.json")
        >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
        ```"""
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                # At this stage we don't have a weight file so we will raise an error.
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                        "weights."
                    )
                else:
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}."
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=filename,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                    "login` and pass `use_auth_token=True`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                if filename == FLAX_WEIGHTS_NAME:
                    has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
                    if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
                            "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
                            "those weights."
                        )
                    else:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
                            f"or {WEIGHTS_NAME}."
                        )
                else:
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
                    )
            except HTTPError:
                raise EnvironmentError(
                    "We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
                    f"{pretrained_model_name_or_path} is not the path to a directory containing a file named "
                    f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
                    "Check your internet connection or see how to run the library in offline mode at "
                    "'https://huggingface.co/docs/transformers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                )

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                    try:
                        with open(resolved_archive_file) as f:
                            if f.read().startswith("version"):
                                raise OSError(
                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                                    "you cloned."
                                )
                            else:
                                raise ValueError from e
                    except (UnicodeDecodeError, ValueError):
                        raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This works around a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            state = jax.tree_util.tree_map(jnp.array, state)

        # if model is base model only use model_prefix key
        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model."
                    )

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # set correct parameters
        model.params = unflatten_dict(state)

        return model

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~FlaxPreTrainedModel.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                <Tip warning={true}>

                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
                folder. Pass along `temp_dir=True` to use a temporary directory instead.

                </Tip>

            kwargs:
                Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
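
        Example (a minimal usage sketch; the save path below is illustrative):

        ```python
        >>> from transformers import FlaxBertModel

        >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
        >>> # writes the Flax weights (`flax_model.msgpack`) and `config.json` into the directory
        >>> model.save_pretrained("./my_model_directory")
        ```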
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.architectures = [self.__class__.__name__[4:]]

        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self.config)

        self.config.save_pretrained(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            params = params if params is not None else self.params
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")

    @classmethod
    def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
                The auto class to register this new model with.
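
        Example (a minimal sketch; `FlaxMyModel` is a hypothetical custom model class):

        ```python
        >>> from transformers import FlaxPreTrainedModel


        >>> class FlaxMyModel(FlaxPreTrainedModel):
        ...     pass


        >>> FlaxMyModel.register_for_auto_class("FlaxAutoModel")
        ```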
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class


# To update the docstring, we need to copy the method, otherwise we change the original docstring.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
    object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
)


def overwrite_call_docstring(model_class, docstring):
    # copy __call__ function to be sure docstring is changed only for this function
    model_class.__call__ = copy_func(model_class.__call__)
    # delete existing docstring
    model_class.__call__.__doc__ = None
    # set correct docstring
    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
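
# Illustrative (hypothetical) usage from a model file, assuming a docstring constant
# `MY_MODEL_INPUTS_DOCSTRING` and a model class `FlaxMyModel`:
#
#     overwrite_call_docstring(FlaxMyModel, MY_MODEL_INPUTS_DOCSTRING)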


def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = add_code_sample_docstrings(
        processor_class=tokenizer_class,
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
    )(model_class.__call__)


def append_replace_return_docstrings(model_class, output_type, config_class):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = replace_return_docstrings(
        output_type=output_type,
        config_class=config_class,
    )(model_class.__call__)
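

# Illustrative (hypothetical) usage from a model file; every name below is a placeholder:
#
#     append_call_sample_docstring(FlaxMyModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, MyConfig)
#     append_replace_return_docstrings(FlaxMyModelForCausalLM, output_type=FlaxCausalLMOutput, config_class=MyConfig)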