# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pickle import UnpicklingError
from typing import Any, Dict, Union

import jax
import jax.numpy as jnp
import msgpack.exceptions
23
from flax.core.frozen_dict import FrozenDict, unfreeze
24
25
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
26
from huggingface_hub import create_repo, hf_hub_download
27
28
29
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

30
31
from .. import __version__, is_torch_available
from ..utils import (
32
33
34
35
36
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    FLAX_WEIGHTS_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    WEIGHTS_NAME,
37
    PushToHubMixin,
38
39
    logging,
)
40
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
41
42
43
44
45


logger = logging.get_logger(__name__)


46
class FlaxModelMixin(PushToHubMixin):
47
    r"""
    Base class for all Flax models.

    [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
    saving models.

        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
    """
55

56
57
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
58
    _flax_internal_args = ["name", "parent", "dtype"]
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        Build an instance of `cls` from `config`, forwarding any extra keyword arguments to the constructor.

        Any context manager the model has to be initialized under belongs in this hook.
        """
        instance = cls(config, **kwargs)
        return instance

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters. Leaves that are not floating-point `jnp.ndarray`s are returned
                unchanged.
            dtype (`jnp.dtype`):
                The dtype floating-point leaves are cast to.
            mask (`Any`, *optional*):
                A `PyTree` of booleans with the same structure as `params`. Only leaves whose mask entry is truthy
                are cast. When `None`, every floating-point leaf is cast.

        Returns:
            A new parameter tree; when a `mask` is given the result is rebuilt with `unflatten_dict`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            # `jax.tree_util` is the supported home of the tree helpers; the top-level aliases are deprecated.
            return jax.tree_util.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_util.tree_flatten(mask)

        # `tree_flatten` emits dictionary leaves in sorted-key order while `flatten_dict` keeps insertion order, so
        # walk the parameter keys in sorted order to pair every mask leaf with the parameter it belongs to.
        for masked, key in zip(flat_mask, sorted(flat_params.keys())):
            if masked:
                flat_params[key] = conditional_cast(flat_params[key])

        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Return a new `params` tree whose floating-point leaves are cast to `jax.numpy.bfloat16`. The `params` passed
        in are not cast in place.

        This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> params = model.to_bf16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_bf16(params, mask)
        ```"""
        target_dtype = jnp.bfloat16
        return self._cast_floating_to(params, target_dtype, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Return a new `params` tree whose floating-point leaves are cast to `jax.numpy.float32`. Use it to explicitly
        bring model parameters back to fp32 precision. The `params` passed in are not cast in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> params = model.to_fp16(params)
        >>> # now cast back to fp32
        >>> params = model.to_fp32(params)
        ```"""
        target_dtype = jnp.float32
        return self._cast_floating_to(params, target_dtype, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Return a new `params` tree whose floating-point leaves are cast to `jax.numpy.float16`. The `params` passed in
        are not cast in place.

        This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to cast these to float16
        >>> params = model.to_fp16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_fp16(params, mask)
        ```"""
        target_dtype = jnp.float16
        return self._cast_floating_to(params, target_dtype, mask)

196
    def init_weights(self, rng: jax.Array) -> Dict:
        """
        Placeholder that subclasses have to override; this base implementation always raises.

        Raises:
            NotImplementedError: Always.
        """
        message = f"init_weights method has to be implemented for {self}"
        raise NotImplementedError(message)

199
200
201
202
203
204
205
206
207
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    ):
        r"""
        Instantiate a pretrained Flax model from a pretrained model configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      using [`~FlaxModelMixin.save_pretrained`].
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified, all the computation will be performed with the given `dtype`.

                <Tip>

                This only specifies the dtype of the *computation* and does not influence the dtype of model
                parameters.

                If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
                [`~FlaxModelMixin.to_bf16`].

                </Tip>

            model_args (sequence of positional arguments, *optional*):
                Accepted in the signature, but not used or forwarded by this method.
            config (*optional*):
                An already loaded configuration. When given, no configuration is loaded from
                `pretrained_model_name_or_path`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (*optional*):
                Forwarded unchanged to the configuration loader and to `hf_hub_download`.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*):
                Sub-directory of the repository or local directory in which the configuration and weights are looked
                up.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Remaining keyword arguments. Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `kwargs` are forwarded unchanged to
                      `cls.from_config`.
                    - If a configuration is not provided, `kwargs` are first passed to `cls.load_config`; whatever it
                      reports back as unused is then forwarded to `cls.from_config`.

        Returns:
            `tuple`: The instantiated model and the nested parameter dictionary read from the checkpoint. Checkpoint
            entries the model does not expect are dropped from the returned parameters.

        Raises:
            EnvironmentError: If the weights file cannot be located, downloaded or deserialized, or if `from_pt=True`
                is requested while PyTorch is not installed.
            ValueError: If a checkpoint entry has a different shape than the matching model parameter.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
        ```
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        from_pt = kwargs.pop("from_pt", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "flax",
        }

        # Name of the weights file this call is actually looking for.
        weights_name = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME

        # When the caller supplies `config`, nothing consumes the remaining kwargs before `from_config`, so they are
        # forwarded as they are. Without this default the name would be undefined on that path.
        unused_kwargs = kwargs

        # Load config if we don't provide one
        if config is None:
            config, unused_kwargs = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                **kwargs,
            )

        model, _ = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

        # Load model
        pretrained_path_with_subfolder = (
            pretrained_model_name_or_path
            if subfolder is None
            else os.path.join(pretrained_model_name_or_path, subfolder)
        )
        if os.path.isdir(pretrained_path_with_subfolder):
            if from_pt:
                if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                    )
                model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
            # Check if pytorch weights exist instead
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
                    " using `from_pt=True`."
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                    f"{pretrained_path_with_subfolder}."
                )
        else:
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=weights_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )

            # Each handler chains the original error so the root cause stays visible in the traceback.
            except RepositoryNotFoundError as err:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                    "login`."
                ) from err
            except RevisionNotFoundError as err:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                ) from err
            except EntryNotFoundError as err:
                # Report the file that was requested, which is the PyTorch file when `from_pt=True`.
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
                ) from err
            except HTTPError as err:
                raise EnvironmentError(
                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                    f"{err}"
                ) from err
            except ValueError as err:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
                    " internet connection or see how to run the library in offline mode at"
                    " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
                ) from err
            except EnvironmentError as err:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                ) from err

        if from_pt:
            # Imported lazily so that PyTorch is only needed when a PyTorch checkpoint is requested.
            if is_torch_available():
                from .modeling_utils import load_state_dict
            else:
                raise EnvironmentError(
                    "Can't load the model in PyTorch format because PyTorch is not installed. "
                    "Please, install PyTorch or use native Flax weights."
                )

            # Step 1: Get the pytorch file
            pytorch_model_file = load_state_dict(model_file)

            # Step 2: Convert the weights
            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
        else:
            try:
                with open(model_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                try:
                    # A file whose text starts with "version" is treated as a git-lfs pointer, not real weights.
                    with open(model_file) as f:
                        if f.read().startswith("version"):
                            raise OSError(
                                "You seem to have cloned a repository without having git-lfs installed. Please"
                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                                " folder you cloned."
                            )
                        else:
                            raise ValueError from e
                except (UnicodeDecodeError, ValueError):
                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")

        # make sure all arrays are stored as jnp.ndarray
        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
        # https://github.com/google/flax/issues/1261
        cpu_device = jax.devices("cpu")[0]
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu_device), state)

        # flatten dicts
        state = flatten_dict(state)

        # Abstract evaluation gives the expected parameter tree without running `init_weights` for real.
        params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
        shape_state = flatten_dict(unfreeze(params_shape_tree))
        required_params = set(shape_state.keys())

        missing_keys = required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - required_params

        if missing_keys:
            logger.warning(
                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
                "Make sure to call model.init_weights to initialize the missing weights."
            )
            cls._missing_keys = missing_keys

        for key, value in state.items():
            if key in shape_state and value.shape != shape_state[key].shape:
                raise ValueError(
                    f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                    f"{value.shape} which is incompatible with the model shape {shape_state[key].shape}. "
                )

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )

        return model, unflatten_dict(state)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params: Union[Dict, FrozenDict],
        is_main_process: bool = True,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save a model and its configuration file to a directory so that it can be reloaded using the
        [`~FlaxModelMixin.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a model and its configuration file to. Will be created if it doesn't exist. If the
                path points to an existing file, an error is logged and nothing is saved.
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. In this method the flag only gates
                writing the configuration file; the weights file is written regardless of its value.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Only read when `push_to_hub=True`: `commit_message`, `private`, `create_pr`, `token` and `repo_id`.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `os.fspath` makes `os.PathLike` inputs usable, and `normpath` drops a trailing separator that would
            # otherwise give an empty default repository name.
            default_repo_id = os.path.basename(os.path.normpath(os.fspath(save_directory)))
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        # Save the config
        if is_main_process:
            self.save_config(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        # Serialize before opening the file so a serialization failure cannot leave a truncated weights file behind.
        model_bytes = to_bytes(params)
        with open(output_model_file, "wb") as f:
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )