# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Optional, Union

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...schedulers import DDPMWuerstchenScheduler
from ...utils import deprecate, replace_example_docstring
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline


# Usage example that `replace_example_docstring` splices into the `Examples:` section of
# `WuerstchenCombinedPipeline.__call__`'s docstring.
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import WuerstchenCombinedPipeline

        >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
        ...     "cuda"
        ... )
        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        >>> images = pipe(prompt=prompt)
        ```
"""


43
class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
    """
    Combined Pipeline for text-to-image generation using Wuerstchen

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Internally the pipeline builds a [`WuerstchenPriorPipeline`] (stored as `prior_pipe`) and a
    [`WuerstchenDecoderPipeline`] (stored as `decoder_pipe`) from the components below and runs them back to back.

    Args:
        tokenizer (`CLIPTokenizer`):
            The decoder tokenizer to be used for text inputs.
        text_encoder (`CLIPTextModel`):
            The decoder text encoder to be used for text inputs.
        decoder (`WuerstchenDiffNeXt`):
            The decoder model to be used for decoder image generation pipeline.
        scheduler (`DDPMWuerstchenScheduler`):
            The scheduler to be used for decoder image generation pipeline.
        vqgan (`PaellaVQModel`):
            The VQGAN model to be used for decoder image generation pipeline.
        prior_tokenizer (`CLIPTokenizer`):
            The prior tokenizer to be used for text inputs.
        prior_text_encoder (`CLIPTextModel`):
            The prior text encoder to be used for text inputs.
        prior_prior (`WuerstchenPrior`):
            The prior model to be used for prior pipeline.
        prior_scheduler (`DDPMWuerstchenScheduler`):
            The scheduler to be used for prior pipeline.
    """

    # NOTE(review): presumably consumed by `DeprecatedPipelineMixin` to report the last release that
    # supported this pipeline -- confirm against the mixin.
    _last_supported_version = "0.33.1"
    # NOTE(review): presumably tells the loading code to also load the connected prior/decoder
    # pipelines -- confirm against `DiffusionPipeline`.
    _load_connected_pipes = True

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        decoder: WuerstchenDiffNeXt,
        scheduler: DDPMWuerstchenScheduler,
        vqgan: PaellaVQModel,
        prior_tokenizer: CLIPTokenizer,
        prior_text_encoder: CLIPTextModel,
        prior_prior: WuerstchenPrior,
        prior_scheduler: DDPMWuerstchenScheduler,
    ):
        super().__init__()

        # Register every component on the combined pipeline itself.
        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            decoder=decoder,
            scheduler=scheduler,
            vqgan=vqgan,
            prior_prior=prior_prior,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
        )
        # The two stage pipelines share the component objects registered above.
        self.prior_pipe = WuerstchenPriorPipeline(
            prior=prior_prior,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
        )
        self.decoder_pipe = WuerstchenDecoderPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            decoder=decoder,
            scheduler=scheduler,
            vqgan=vqgan,
        )

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Enable xformers memory efficient attention on the decoder pipeline.

        Only `decoder_pipe` is affected; the call is not forwarded to `prior_pipe`.

        Args:
            attention_op (`Callable`, *optional*):
                Passed through unchanged to `decoder_pipe.enable_xformers_memory_efficient_attention`.
        """
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

    def enable_model_cpu_offload(
        self, gpu_id: Optional[int] = None, device: Optional[Union[torch.device, str]] = None
    ):
        r"""
        Enable whole-model CPU offloading on both the prior and the decoder pipeline.

        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better.

        Args:
            gpu_id (`int`, *optional*):
                Forwarded to `enable_model_cpu_offload` of both sub-pipelines.
            device (`torch.device` or `str`, *optional*):
                Forwarded to `enable_model_cpu_offload` of both sub-pipelines.
        """
        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)

    def enable_sequential_cpu_offload(
        self, gpu_id: Optional[int] = None, device: Optional[Union[torch.device, str]] = None
    ):
        r"""
        Enable sequential (submodule-level) CPU offloading on both the prior and the decoder pipeline.

        Offloads all models to CPU using Accelerate, significantly reducing memory usage. Models are moved to a
        `torch.device('meta')` and loaded on a GPU only when their specific submodule's `forward` method is called.
        Offloading happens on a submodule basis. Memory savings are higher than using `enable_model_cpu_offload`, but
        performance is lower.

        Args:
            gpu_id (`int`, *optional*):
                Forwarded to `enable_sequential_cpu_offload` of both sub-pipelines.
            device (`torch.device` or `str`, *optional*):
                Forwarded to `enable_sequential_cpu_offload` of both sub-pipelines.
        """
        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

    def progress_bar(self, iterable=None, total=None):
        """
        Call `progress_bar` on the prior pipeline and then on the decoder pipeline with the same arguments.

        The objects returned by the sub-pipelines are discarded, so this method itself returns `None`.
        """
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)

    def set_progress_bar_config(self, **kwargs):
        """Forward the given progress bar configuration to both the prior and the decoder pipeline."""
        self.prior_pipe.set_progress_bar_config(**kwargs)
        self.decoder_pipe.set_progress_bar_config(**kwargs)

    # NOTE: `replace_example_docstring` substitutes the standalone `Examples:` line of the docstring
    # below, so that line has to stay in place.
    @torch.no_grad()
    @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        prior_num_inference_steps: int = 60,
        prior_timesteps: Optional[List[float]] = None,
        prior_guidance_scale: float = 4.0,
        num_inference_steps: int = 12,
        decoder_timesteps: Optional[List[float]] = None,
        decoder_guidance_scale: float = 0.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        prior_callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation for the prior and decoder.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
                prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of
                equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
                setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
                closely linked to the text `prompt`, usually at the expense of lower image quality.
            prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
                The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. For more specific timestep spacing, you can pass customized
                `prior_timesteps`
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
                the expense of slower inference. For more specific timestep spacing, you can pass customized
                `timesteps`
            prior_timesteps (`List[float]`, *optional*):
                Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced
                `prior_num_inference_steps` timesteps are used. Must be in descending order.
            decoder_timesteps (`List[float]`, *optional*):
                Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
                `num_inference_steps` timesteps are used. Must be in descending order.
            decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            prior_callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
                int, callback_kwargs: Dict)`.
            prior_callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
                The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
                list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            kwargs:
                The deprecated `prior_callback` and `prior_callback_steps` entries are consumed here and handed to the
                prior pipeline as `callback` / `callback_steps`. Everything else is forwarded to the decoder pipeline.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # Resolve the list defaults per call instead of sharing one mutable default object across calls.
        if prior_callback_on_step_end_tensor_inputs is None:
            prior_callback_on_step_end_tensor_inputs = ["latents"]
        if callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        # Legacy callback arguments are always removed from `kwargs` so that they (including explicit
        # `None` values) never leak into the decoder call below; only non-None values are honoured.
        prior_kwargs = {}
        prior_callback = kwargs.pop("prior_callback", None)
        if prior_callback is not None:
            deprecate(
                "prior_callback",
                "1.0.0",
                "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
            )
            prior_kwargs["callback"] = prior_callback
        prior_callback_steps = kwargs.pop("prior_callback_steps", None)
        if prior_callback_steps is not None:
            deprecate(
                "prior_callback_steps",
                "1.0.0",
                "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
            )
            prior_kwargs["callback_steps"] = prior_callback_steps

        # Stage 1: run the prior. The text prompts are only passed when no pre-computed embeddings were
        # supplied. The result is always requested as a tensor tuple because it feeds the decoder.
        prior_outputs = self.prior_pipe(
            prompt=prompt if prompt_embeds is None else None,
            height=height,
            width=width,
            num_inference_steps=prior_num_inference_steps,
            timesteps=prior_timesteps,
            guidance_scale=prior_guidance_scale,
            negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            num_images_per_prompt=num_images_per_prompt,
            generator=generator,
            latents=latents,
            output_type="pt",
            return_dict=False,
            callback_on_step_end=prior_callback_on_step_end,
            callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
            **prior_kwargs,
        )
        image_embeddings = prior_outputs[0]

        # Stage 2: run the decoder on the prior's output. The decoder always receives a text prompt, so
        # an empty string stands in when the caller supplied none.
        outputs = self.decoder_pipe(
            image_embeddings=image_embeddings,
            prompt=prompt if prompt is not None else "",
            num_inference_steps=num_inference_steps,
            timesteps=decoder_timesteps,
            guidance_scale=decoder_guidance_scale,
            negative_prompt=negative_prompt,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=callback_on_step_end,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            **kwargs,
        )

        return outputs