# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
# William Peebles and Saining Xie
#
# Copyright (c) 2021 OpenAI
# MIT License
#
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple, Union

import torch

from ...models import AutoencoderKL, DiTTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


# Record once, at module import, whether torch_xla can be used. `xm` is only
# bound when the flag is True, so every use of `xm` must be guarded by it.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm


Kashif Rasul's avatar
Kashif Rasul committed
40
41
class DiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation based on a Transformer backbone instead of a UNet.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        transformer ([`DiTTransformer2DModel`]):
            A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        id2label (`dict` mapping `int` to `str`, *optional*):
            Mapping from class id to a comma-separated string of label names. Every name is stripped of surrounding
            whitespace and stored in `self.labels` as `name -> class id`. When omitted, `self.labels` is empty.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(
        self,
        transformer: DiTTransformer2DModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
        id2label: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

        # Build a label-name -> class-id dictionary for easier use. One id may carry several
        # comma-separated synonyms; each synonym becomes its own key.
        self.labels = {}
        if id2label is not None:
            for key, value in id2label.items():
                for label in value.split(","):
                    self.labels[label.strip()] = int(key)
            # Keep the mapping sorted by label name so that printing it is easy to scan.
            self.labels = dict(sorted(self.labels.items()))

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map label strings to their corresponding class ids using `self.labels`.

        Parameters:
            label (`str` or `list` of `str`):
                Label string(s) to be mapped to class ids. A single string is treated as one label.

        Returns:
            `list` of `int`:
                Class ids to be processed by pipeline, in the same order as the given labels.

        Raises:
            `ValueError`: If any of the given labels is not a key of `self.labels`.
        """

        if isinstance(label, str):
            # Wrap a lone string; `list("white shark")` would split it into single characters.
            label = [label]
        elif not isinstance(label, list):
            label = list(label)

        for name in label:
            if name not in self.labels:
                raise ValueError(
                    f"{name} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
                )

        return [self.labels[name] for name in label]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: List[int],
        guidance_scale: float = 4.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            class_labels (List[int]):
                List of class ids for the images to be generated. One image is generated per entry.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the requested
                class labels at the expense of lower image quality. Guidance scale is enabled when
                `guidance_scale > 1`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` returns `PIL.Image` objects; any other value
                returns the images as an `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        ```py
        >>> from diffusers import DiTPipeline, DPMSolverMultistepScheduler
        >>> import torch

        >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
        >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        >>> pipe = pipe.to("cuda")

        >>> # pick words from Imagenet class labels
        >>> pipe.labels  # to print all available words

        >>> # pick words that exist in ImageNet
        >>> words = ["white shark", "umbrella"]

        >>> class_ids = pipe.get_label_ids(words)

        >>> generator = torch.manual_seed(33)
        >>> output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

        >>> image = output.images[0]  # label 'white shark'
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """

        batch_size = len(class_labels)
        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels
        # Classifier-free guidance doubles the batch: first half conditional, second half unconditional.
        do_classifier_free_guidance = guidance_scale > 1

        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            device=self._execution_device,
            dtype=self.transformer.dtype,
        )
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1)
        # NOTE(review): 1000 is the label fed to the unconditional half of the batch; presumably the
        # "null" class index of the pretrained transformer -- confirm against the model config.
        class_null = torch.tensor([1000] * batch_size, device=self._execution_device)
        class_labels_input = (
            torch.cat([class_labels, class_null], 0) if do_classifier_free_guidance else class_labels
        )

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.progress_bar(self.scheduler.timesteps):
            if do_classifier_free_guidance:
                # Both halves must start each step from the same latents, so duplicate the first half.
                half = latent_model_input[: len(latent_model_input) // 2]
                latent_model_input = torch.cat([half, half], dim=0)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timesteps = t
            if not torch.is_tensor(timesteps):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                is_npu = latent_model_input.device.type == "npu"
                # Narrower dtypes are selected on "mps" and "npu" devices.
                if isinstance(timesteps, float):
                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                else:
                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(latent_model_input.shape[0])
            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input, timestep=timesteps, class_labels=class_labels_input
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                # Only the first `latent_channels` channels are guided; the rest pass through untouched.
                eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
                cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = torch.cat([half_eps, half_eps], dim=0)

                noise_pred = torch.cat([eps, rest], dim=1)

            # learned sigma: when the model emits twice as many channels as the latents, keep only the
            # first `latent_channels` channels as the model output.
            if self.transformer.config.out_channels // 2 == latent_channels:
                model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
            else:
                model_output = noise_pred

            # compute previous image: x_t -> x_t-1
            latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

            if XLA_AVAILABLE:
                xm.mark_step()

        if do_classifier_free_guidance:
            # Drop the duplicated second half of the batch.
            latents, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents = latent_model_input

        latents = 1 / self.vae.config.scaling_factor * latents
        samples = self.vae.decode(latents).sample

        # Map the decoded samples from [-1, 1] to [0, 1].
        samples = (samples / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            samples = self.numpy_to_pil(samples)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (samples,)

        return ImagePipelineOutput(images=samples)