# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pickle import UnpicklingError
from typing import Any, Dict, Union

import jax
import jax.numpy as jnp
import msgpack.exceptions
23
from flax.core.frozen_dict import FrozenDict, unfreeze
24
25
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
26
from huggingface_hub import create_repo, hf_hub_download
27
28
from huggingface_hub.utils import (
    EntryNotFoundError,
29
    HfHubHTTPError,
30
31
32
33
    RepositoryNotFoundError,
    RevisionNotFoundError,
    validate_hf_hub_args,
)
34

35
36
from .. import __version__, is_torch_available
from ..utils import (
37
38
39
40
    CONFIG_NAME,
    FLAX_WEIGHTS_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    WEIGHTS_NAME,
41
    PushToHubMixin,
42
43
    logging,
)
44
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
45
46
47
48
49


logger = logging.get_logger(__name__)


50
class FlaxModelMixin(PushToHubMixin):
    r"""
    Base class for all Flax models.

    [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
    saving models.

        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
    """

60
61
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
62
    _flax_internal_args = ["name", "parent", "dtype"]
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters. Leaves that are not `jnp.ndarray` instances, or whose dtype is not
                a floating-point dtype, are returned unchanged.
            dtype (`jnp.dtype`):
                The dtype floating-point leaves are cast to.
            mask (`Union[Dict, FrozenDict]`, *optional*):
                A nested dict of booleans keyed like `params`. Only leaves whose mask entry is truthy are cast;
                parameters without a mask entry are left untouched. If `None`, every floating-point leaf is cast.

        Returns:
            A new parameter tree; `params` itself is not modified. When `mask` is given the result is rebuilt with
            `unflatten_dict`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            # `jax.tree_map` was a deprecated top-level alias; `jax.tree_util.tree_map` is the stable spelling.
            return jax.tree_util.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        # Flatten the mask the same way as the params and pair the two by key path. Pairing them positionally is
        # fragile: `jax.tree_util.tree_flatten` orders dict keys and drops `None` leaves, whereas `flatten_dict`
        # keeps insertion order, so positions are not guaranteed to line up.
        flat_mask = flatten_dict(mask)

        cast_params = {
            path: conditional_cast(param) if flat_mask.get(path, False) else param
            for path, param in flat_params.items()
        }

        return unflatten_dict(cast_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> params = model.to_bf16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_bf16(params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.bfloat16, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> params = model.to_fp16(params)
        >>> # now cast back to fp32
        >>> params = model.to_fp32(params)
        ```"""
        return self._cast_floating_to(params, jnp.float32, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
        `params` in place.

        This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to cast these to float16
        >>> params = model.to_fp16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_fp16(params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

200
    def init_weights(self, rng: jax.Array) -> Dict:
        """
        Interface for building the model's parameter tree from the PRNG key `rng`.

        This base implementation always raises `NotImplementedError`; concrete models are expected to override it.
        """
        message = f"init_weights method has to be implemented for {self}"
        raise NotImplementedError(message)

203
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    ):
        r"""
        Instantiate a pretrained Flax model from a pretrained model configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* (for example `stable-diffusion-v1-5/stable-diffusion-v1-5`) of a
                      pretrained model hosted on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      using [`~FlaxModelMixin.save_pretrained`].
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified, all the computation will be performed with the given `dtype`.

                > [!TIP]
                > This only specifies the dtype of the *computation* and does not influence the dtype of model
                > parameters.
                >
                > If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
                > [`~FlaxModelMixin.to_bf16`].

            model_args (sequence of positional arguments, *optional*):
                Accepted for signature compatibility. This method does not currently forward them anywhere.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (*optional*):
                Forwarded as `token` to the configuration loader and to `hf_hub_download`.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*):
                Sub-directory of the local directory or Hub repository in which the weights file is looked up.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the model (for
                example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
                      model's `__init__` method (we assume all relevant updates to the configuration have already been
                      done).
                    - If a configuration is not provided, `kwargs` are first passed to the configuration class
                      initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
                      to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
                      Remaining keys that do not correspond to any configuration attribute are passed to the underlying
                      model's `__init__` function.

        Returns:
            A `(model, params)` tuple: the instantiated model and the nested dict of loaded parameters. Checkpoint
            entries the model does not expect are dropped from `params`.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
        ```

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```bash
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        logger.warning(
            "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
            "recommend migrating to PyTorch classes or pinning your version of Diffusers."
        )
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        from_pt = kwargs.pop("from_pt", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "flax",
        }

        # Name of the weights file this call downloads from the Hub
        weights_name = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME

        # Load config if we don't provide one
        if config is None:
            config, unused_kwargs = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                **kwargs,
            )
        else:
            # A config was supplied by the caller: the remaining kwargs go straight to model construction.
            # Without this branch `unused_kwargs` would be undefined below.
            unused_kwargs = kwargs

        model, _ = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

        # Load model
        pretrained_path_with_subfolder = (
            pretrained_model_name_or_path
            if subfolder is None
            else os.path.join(pretrained_model_name_or_path, subfolder)
        )
        if os.path.isdir(pretrained_path_with_subfolder):
            if from_pt:
                if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                    )
                model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
            # Check if pytorch weights exist instead
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
                    " using `from_pt=True`."
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                    f"{pretrained_path_with_subfolder}."
                )
        else:
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=weights_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )

            except RepositoryNotFoundError as err:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `token` or log in with `hf auth login`."
                ) from err
            except RevisionNotFoundError as err:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                ) from err
            except EntryNotFoundError as err:
                # Report the file that was actually requested (PyTorch weights when `from_pt=True`).
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
                ) from err
            except HfHubHTTPError as err:
                raise EnvironmentError(
                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                    f"{err}"
                ) from err
            except ValueError as err:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
                    " internet connection or see how to run the library in offline mode at"
                    " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
                ) from err
            except EnvironmentError as err:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                ) from err

        if from_pt:
            if is_torch_available():
                from .modeling_utils import load_state_dict
            else:
                raise EnvironmentError(
                    "Can't load the model in PyTorch format because PyTorch is not installed. "
                    "Please, install PyTorch or use native Flax weights."
                )

            # Step 1: Get the pytorch file
            pytorch_model_file = load_state_dict(model_file)

            # Step 2: Convert the weights
            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
        else:
            try:
                with open(model_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                try:
                    # A text file starting with "version" is treated as a git-lfs pointer, not real weights.
                    with open(model_file) as f:
                        if f.read().startswith("version"):
                            raise OSError(
                                "You seem to have cloned a repository without having git-lfs installed. Please"
                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                                " folder you cloned."
                            )
                        else:
                            raise ValueError from e
                except (UnicodeDecodeError, ValueError):
                    raise EnvironmentError(
                        f"Unable to convert {model_file} to Flax deserializable object. "
                    ) from e
            # make sure all arrays are stored as jnp.ndarray
            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)

        # flatten dicts
        state = flatten_dict(state)

        # `eval_shape` yields the expected parameter shapes without materializing the weights
        params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
        shape_state = flatten_dict(unfreeze(params_shape_tree))
        required_params = set(shape_state.keys())

        missing_keys = required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - required_params

        if missing_keys:
            logger.warning(
                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
                "Make sure to call model.init_weights to initialize the missing weights."
            )
            cls._missing_keys = missing_keys

        for key in state.keys():
            if key in shape_state and state[key].shape != shape_state[key].shape:
                raise ValueError(
                    f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                    f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
                )

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )

        return model, unflatten_dict(state)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params: Union[Dict, FrozenDict],
        is_main_process: bool = True,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save a model and its configuration file to a directory so that it can be reloaded using the
        [`~FlaxModelMixin.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a model and its configuration file to. Will be created if it doesn't exist.
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions. Only the configuration file is gated by this flag; the weights file
                is written by every process that calls this method.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", None)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `os.fspath` lets `os.PathLike` directories (e.g. `pathlib.Path`) work too; they have no `.split`.
            default_repo_id = os.fspath(save_directory).split(os.path.sep)[-1]
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        # Save the config
        if is_main_process:
            self.save_config(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )