# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union
import torch

from ...models import UNet2DModel
from ...schedulers import DDPMScheduler
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


# Module-level flag recording whether torch_xla can be used; the pipeline loop
# checks it to decide whether to call `xm.mark_step()` after each denoising step.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm


class DDPMPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` used to denoise the image samples directly (this pipeline has no VAE/latent stage).
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html), or a list of
                them, to make generation deterministic. It is used both for the initial noise and for every
                scheduler step.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` returns `PIL.Image` objects; any other value
                returns a `np.ndarray` of shape `(batch_size, height, width, channels)` with values in `[0, 1]`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from diffusers import DDPMPipeline

        >>> # load model and scheduler
        >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> image = pipe().images[0]

        >>> # save image
        >>> image.save("ddpm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned whose first element holds the generated images (a list of PIL images when `output_type`
                is `"pil"`, otherwise a NumPy array).
        """
        # Sample gaussian noise to begin loop. `sample_size` may be a single int (square images)
        # or a sequence giving the spatial dimensions.
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if self.device.type == "mps":
            # randn does not work reproducibly on mps, so sample on the default device and move it over
            image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. run the UNet on the current noisy sample at timestep t
            model_output = self.unet(image, t).sample

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            if XLA_AVAILABLE:
                # On XLA devices, cut the lazily traced graph once per denoising step.
                xm.mark_step()

        # Map the sample from its presumed [-1, 1] range to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)
        # Channels-last for image libraries. `Tensor.numpy()` rejects bfloat16, so upcast to
        # float32 first (half-precision UNets would otherwise fail here).
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)