# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
from typing import Any, Dict, List, Optional, Union

import flax
import numpy as np
import PIL.Image
from flax.core.frozen_dict import FrozenDict
from huggingface_hub import create_repo, snapshot_download
from PIL import Image
from tqdm.auto import tqdm

from ..configuration_utils import ConfigMixin
from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from ..utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    BaseOutput,
    PushToHubMixin,
    http_user_agent,
    is_transformers_available,
    logging,
)


if is_transformers_available():
    from transformers import FlaxPreTrainedModel

# Filename of a Flax diffusion weights index file.
# NOTE(review): not referenced anywhere else in this module's visible code — confirm whether it is still used.
INDEX_FILE = "diffusion_flax_model.bin"


# Module-level logger from the package's `logging` utilities; `from_pretrained` uses it to emit loading warnings.
logger = logging.get_logger(__name__)


# For every library that can provide pipeline components, the base classes a pipeline knows how to persist.
# Each base class maps to its `[save_method_name, load_method_name]` pair; `FlaxDiffusionPipeline.save_pretrained`
# uses index 0 and `FlaxDiffusionPipeline.from_pretrained` uses index 1.
LOADABLE_CLASSES = {
    "diffusers": {
        "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
        "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
        "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
}

# Flat view of LOADABLE_CLASSES across all libraries: base class name -> [save method, load method].
# Later libraries win on duplicate class names, matching successive `dict.update` calls.
ALL_IMPORTABLE_CLASSES = {
    base_class_name: methods
    for library_classes in LOADABLE_CLASSES.values()
    for base_class_name, methods in library_classes.items()
}


def import_flax_or_no_model(module, class_name):
    """
    Look up a class on `module`, preferring its Flax variant.

    Args:
        module: The module (or any attribute container) to look the class up on.
        class_name (`str`): The class name, without a `Flax` prefix.

    Returns:
        `module.Flax<class_name>` if it exists, otherwise `module.<class_name>`.

    Raises:
        ValueError: If neither `Flax<class_name>` nor `<class_name>` is an attribute of `module`.
    """
    try:
        # 1. First make sure that if a Flax object is present, import this one
        return getattr(module, "Flax" + class_name)
    except AttributeError:
        pass

    # 2. If this doesn't work, it's not a model and we don't append "Flax".
    # This lookup must live in its own `try`: a second `except` clause on the first `try` would never
    # catch an AttributeError raised inside the first handler.
    try:
        return getattr(module, class_name)
    except AttributeError:
        raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") from None


@flax.struct.dataclass
class FlaxImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
    r"""
    Base class for Flax-based pipelines.

    [`FlaxDiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and
    provides methods for loading, downloading and saving models. It also includes methods to:

        - enable/disable the progress bar for the denoising iteration

    Class attributes:

        - **config_name** ([`str`]) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.
    """

    config_name = "model_index.json"

    def register_modules(self, **kwargs):
        """
        Register pipeline components as attributes and record where they come from in the pipeline config.

        For each keyword argument `name=module`, `(library, class_name)` is written to the config with
        `register_to_config` (`(None, None)` when `module` is `None`), and `module` is set as attribute `name`.
        """
        # import it here to avoid circular import
        from diffusers import pipelines

        for name, module in kwargs.items():
            if module is None:
                register_dict = {name: (None, None)}
            else:
                module_path = module.__module__.split(".")
                # retrieve library
                library = module_path[0]

                # check if the module is a pipeline module (i.e. its parent package is a `diffusers.pipelines`
                # sub-package). A class defined in a top-level module such as `__main__` has no parent package.
                pipeline_dir = module_path[-2] if len(module_path) > 1 else None
                is_pipeline_module = pipeline_dir is not None and hasattr(pipelines, pipeline_dir)

                # if library is not in LOADABLE_CLASSES, then it is a custom module.
                # Or if it's a pipeline module, then the module is inside the pipeline
                # folder so we set the library to module name.
                if library not in LOADABLE_CLASSES or is_pipeline_module:
                    library = pipeline_dir if pipeline_dir is not None else module.__module__

                # retrieve class_name
                class_name = module.__class__.__name__

                register_dict = {name: (library, class_name)}

            # save model index config
            self.register_to_config(**register_dict)

            # set models
            setattr(self, name, module)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params: Union[Dict, FrozenDict],
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
        class implements both a save and loading method. The pipeline is easily reloaded using the
        [`~FlaxDiffusionPipeline.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            params (`Dict` or `FrozenDict`):
                Parameters of the pipeline components, keyed by component name. A component's entry is only read when
                its save method accepts a `params` argument.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Raises:
            ValueError: If a component does not inherit from any base class listed in `LOADABLE_CLASSES`.
        """
        # TODO: handle inference_state
        self.save_config(save_directory)

        model_index_dict = dict(self.config)
        model_index_dict.pop("_class_name")
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module", None)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `os.fspath` accepts `os.PathLike` (e.g. `pathlib.Path`); `normpath` drops a trailing separator
            # so the default repo name is never empty.
            default_repo_id = os.path.basename(os.path.normpath(os.fspath(save_directory)))
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)
            if sub_model is None:
                # edge case for saving a pipeline with safety_checker=None
                continue

            model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class, None)
                    if class_candidate is not None and issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            if save_method_name is None:
                raise ValueError(
                    f"Component `{pipeline_component_name}` of type {model_cls} does not inherit from any class in"
                    " `LOADABLE_CLASSES` and cannot be saved."
                )

            save_method = getattr(sub_model, save_method_name)
            expects_params = "params" in set(inspect.signature(save_method).parameters.keys())

            if expects_params:
                save_method(
                    os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name]
                )
            else:
                save_method(os.path.join(save_directory, pipeline_component_name))

        # Upload once, after every component has been written (not once per component).
        if push_to_hub:
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a Flax-based diffusion pipeline from pretrained pipeline weights.

        The pipeline is set in evaluation mode (`model.eval()`) by default and dropout modules are deactivated.

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```
        Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        ```

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *repo id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained pipeline
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      using [`~FlaxDiffusionPipeline.save_pretrained`].
            dtype (`str` or `jnp.dtype`, *optional*):
                Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is
                automatically derived from the model's weights.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Directory in which downloaded files are cached; passed to the config loader and to
                `snapshot_download`.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            from_pt (`bool`, *optional*, defaults to `False`):
                Whether to load component weights from PyTorch checkpoints. It is forwarded to the components' load
                methods, and `*.bin` / `*.safetensors` files are only downloaded when it is `True`.
            use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
                Forwarded to the `from_pretrained` of every [`FlaxModelMixin`] component.
            split_head_dim (`bool`, *optional*, defaults to `False`):
                Forwarded to the `from_pretrained` of every [`FlaxModelMixin`] component.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (the pipeline components) of the specific pipeline
                class. The overwritten components are passed directly to the pipelines `__init__` method.

        Returns:
            `tuple`: The instantiated pipeline and a `dict` mapping component names to their loaded parameters (or,
            for schedulers, their scheduler state).

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`.

        </Tip>

        Examples:

        ```py
        >>> import jax.numpy as jnp
        >>> from diffusers import FlaxDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> # Requires to be logged in to Hugging Face hub,
        >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
        >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5",
        ...     revision="bf16",
        ...     dtype=jnp.bfloat16,
        ... )

        >>> # Download pipeline, but use a different scheduler
        >>> from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline

        >>> model_id = "runwayml/stable-diffusion-v1-5"
        >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
        ...     model_id,
        ...     subfolder="scheduler",
        ... )

        >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
        ...     model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
        ... )
        >>> dpm_params["scheduler"] = dpmpp_state
        ```
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pt = kwargs.pop("from_pt", False)
        use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
        split_head_dim = kwargs.pop("split_head_dim", False)
        dtype = kwargs.pop("dtype", None)

        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
            config_dict = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
            # make sure we only download sub-folders and `diffusers` filenames
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]

            ignore_patterns = ["*.bin", "*.safetensors"] if not from_pt else []
            ignore_patterns += ["*.onnx", "*.onnx_data", "*.xml", "*.pb"]

            if cls != FlaxDiffusionPipeline:
                requested_pipeline_class = cls.__name__
            else:
                requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
                requested_pipeline_class = (
                    requested_pipeline_class
                    if requested_pipeline_class.startswith("Flax")
                    else "Flax" + requested_pipeline_class
                )

            user_agent = {"pipeline_class": requested_pipeline_class}
            user_agent = http_user_agent(user_agent)

            # download all allow_patterns
            cached_folder = snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                user_agent=user_agent,
            )
        else:
            cached_folder = pretrained_model_name_or_path

        config_dict = cls.load_config(cached_folder)

        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
        if cls != FlaxDiffusionPipeline:
            pipeline_class = cls
        else:
            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
            class_name = (
                config_dict["_class_name"]
                if config_dict["_class_name"].startswith("Flax")
                else "Flax" + config_dict["_class_name"]
            )
            pipeline_class = getattr(diffusers_module, class_name)

        # some modules can be passed directly to the init
        # in this case they are already instantiated in `kwargs`
        # extract them here
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        # define init kwargs
        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

        # remove `null` components
        def load_module(name, value):
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        # Throw nice warnings / errors for fast accelerate loading
        if len(unused_kwargs) > 0:
            logger.warning(
                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
            )

        # inference_params
        params = {}

        # import it here to avoid circular import
        from diffusers import pipelines

        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            if class_name is None:
                # edge case for when the pipeline was saved with safety_checker=None
                init_kwargs[name] = None
                continue

            is_pipeline_module = hasattr(pipelines, library_name)
            loaded_sub_model = None
            sub_model_should_be_defined = True

            # if the model is in a pipeline module, then we load it from the pipeline
            if name in passed_class_obj:
                # 1. check that passed_class_obj has correct parent class
                if not is_pipeline_module:
                    library = importlib.import_module(library_name)
                    class_obj = getattr(library, class_name)
                    importable_classes = LOADABLE_CLASSES[library_name]
                    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

                    expected_class_obj = None
                    for class_candidate in class_candidates.values():
                        if class_candidate is not None and issubclass(class_obj, class_candidate):
                            expected_class_obj = class_candidate

                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
                        raise ValueError(
                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                            f" {expected_class_obj}"
                        )
                elif passed_class_obj[name] is None:
                    logger.warning(
                        f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
                        f" that this might lead to problems when using {pipeline_class} and is not recommended."
                    )
                    sub_model_should_be_defined = False
                else:
                    logger.warning(
                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
                        " has the correct type"
                    )

                # set passed class object
                loaded_sub_model = passed_class_obj[name]
            elif is_pipeline_module:
                pipeline_module = getattr(pipelines, library_name)
                class_obj = import_flax_or_no_model(pipeline_module, class_name)

                importable_classes = ALL_IMPORTABLE_CLASSES
                class_candidates = {c: class_obj for c in importable_classes.keys()}
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)
                class_obj = import_flax_or_no_model(library, class_name)

                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

            if loaded_sub_model is None and sub_model_should_be_defined:
                # The last matching base class determines the load method (no early `break`).
                load_method_name = None
                for candidate_name, class_candidate in class_candidates.items():
                    if class_candidate is not None and issubclass(class_obj, class_candidate):
                        load_method_name = importable_classes[candidate_name][1]

                if load_method_name is None:
                    raise ValueError(
                        f"Component `{name}` of type {class_obj} does not inherit from any loadable base class and"
                        " cannot be loaded."
                    )
                load_method = getattr(class_obj, load_method_name)

                # check if the module is in a subdirectory; otherwise load it from the pipeline root folder
                if os.path.isdir(os.path.join(cached_folder, name)):
                    loadable_folder = os.path.join(cached_folder, name)
                else:
                    loadable_folder = cached_folder

                if issubclass(class_obj, FlaxModelMixin):
                    loaded_sub_model, loaded_params = load_method(
                        loadable_folder,
                        from_pt=from_pt,
                        use_memory_efficient_attention=use_memory_efficient_attention,
                        split_head_dim=split_head_dim,
                        dtype=dtype,
                    )
                    params[name] = loaded_params
                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                    if from_pt:
                        # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
                        loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
                        loaded_params = loaded_sub_model.params
                        del loaded_sub_model._params
                    else:
                        loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                    params[name] = loaded_params
                elif issubclass(class_obj, FlaxSchedulerMixin):
                    loaded_sub_model, scheduler_state = load_method(loadable_folder)
                    params[name] = scheduler_state
                else:
                    loaded_sub_model = load_method(loadable_folder)

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

        # 4. Potentially add passed objects if expected
        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())

        if len(missing_modules) > 0 and missing_modules <= set(passed_modules):
            for module in missing_modules:
                init_kwargs[module] = passed_class_obj.get(module, None)
        elif len(missing_modules) > 0:
            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

        model = pipeline_class(**init_kwargs, dtype=dtype)
        return model, params

    @classmethod
    def _get_signature_keys(cls, obj):
        """
        Split the parameters of `obj.__init__` into required and optional names.

        Returns:
            `tuple(set, set)`: Names of parameters without a default (excluding `self`), and names of parameters
            with a default.
        """
        parameters = inspect.signature(obj.__init__).parameters
        # Identity check against the sentinel: `==` could return an array for array-valued defaults.
        required_parameters = {k: v for k, v in parameters.items() if v.default is inspect.Parameter.empty}
        optional_parameters = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}
        expected_modules = set(required_parameters.keys()) - {"self"}
        return expected_modules, optional_parameters

    @property
    def components(self) -> Dict[str, Any]:
        r"""

        The `self.components` property can be useful to run different pipelines with the same weights and
        configurations to not have to re-allocate memory.

        Examples:

        ```py
        >>> import jax.numpy as jnp
        >>> from diffusers import (
        ...     FlaxStableDiffusionPipeline,
        ...     FlaxStableDiffusionImg2ImgPipeline,
        ... )

        >>> text2img, params = FlaxStableDiffusionPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
        ... )
        >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
        ```

        Returns:
            A dictionary containing all the modules needed to initialize the pipeline.

        Raises:
            ValueError: If the registered components do not match the required `__init__` parameters.
        """
        expected_modules, optional_parameters = self._get_signature_keys(self)
        components = {
            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
        }

        if set(components.keys()) != expected_modules:
            raise ValueError(
                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
                f" {expected_modules} to be defined, but {components} are defined."
            )

        return components

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a NumPy image or a batch of images to a PIL image.

        A 3-D array is treated as a single image. Values are multiplied by 255, rounded and cast to `uint8`
        (so inputs are presumably in `[0, 1]`). Single-channel images are converted in `"L"` (grayscale) mode.

        Returns:
            `List[PIL.Image.Image]`: One PIL image per item in the batch.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    # TODO: make it compatible with jax.lax
    def progress_bar(self, iterable):
        """
        Wrap `iterable` in a `tqdm` progress bar using the options stored by `set_progress_bar_config`.

        Raises:
            ValueError: If the stored progress bar config is not a `dict`.
        """
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        return tqdm(iterable, **self._progress_bar_config)

    def set_progress_bar_config(self, **kwargs):
        """Store keyword arguments to be forwarded to `tqdm` by `progress_bar` (replacing any previous config)."""
        self._progress_bar_config = kwargs