pipeline_unclip.py 21.7 KB
Newer Older
1
# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved.
Will Berman's avatar
Will Berman committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
16
from typing import List, Optional, Tuple, Union
Will Berman's avatar
Will Berman committed
17
18
19
20

import torch
from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
21
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
Will Berman's avatar
Will Berman committed
22

23
24
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
Dhruv Nair's avatar
Dhruv Nair committed
25
26
from ...utils import logging
from ...utils.torch_utils import randn_tensor
YiYi Xu's avatar
YiYi Xu committed
27
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
Will Berman's avatar
Will Berman committed
28
29
30
31
32
33
34
from .text_proj import UnCLIPTextProjModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class UnCLIPPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution UNet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution UNet. Used in the last step of the super resolution diffusion process.
        prior_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]).
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]).
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]).

    """

    # Components named here are presumably skipped by the `DiffusionPipeline` CPU-offload machinery
    # (consumer not visible in this file). Note that "prior" is also absent from `model_cpu_offload_seq`.
    _exclude_from_cpu_offload = ["prior"]

    prior: PriorTransformer
    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    prior_scheduler: UnCLIPScheduler
    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    # Order in which components are visited by model CPU offloading (consumed by `DiffusionPipeline`,
    # not visible here); it follows the order in which `__call__` uses them.
    model_cpu_offload_seq = "text_encoder->text_proj->decoder->super_res_first->super_res_last"

    def __init__(
        self,
        prior: PriorTransformer,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        prior_scheduler: UnCLIPScheduler,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            prior_scheduler=prior_scheduler,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        """
        Return the initial noisy latents for one denoising stage, scaled by `scheduler.init_noise_sigma`.

        Args:
            shape (`tuple`): Expected latent shape.
            dtype (`torch.dtype`): dtype used when sampling fresh noise.
            device (`torch.device`): Device the latents are placed on.
            generator: Generator(s) forwarded to `randn_tensor` when sampling fresh noise.
            latents (`torch.Tensor` or `None`): User-supplied latents; sampled when `None`.
            scheduler: Scheduler whose `init_noise_sigma` scales the latents.

        Raises:
            `ValueError`: If user-supplied `latents` do not have exactly `shape`.
        """
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            # Only the device is changed; user-supplied latents keep their own dtype.
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Encode the prompt (or take pre-computed text-encoder outputs) for the prior and decoder.

        When `text_model_output` is given, `prompt` is ignored. Element 0 of `text_model_output` is used as the
        projected text embedding and element 1 as the last hidden state.

        Args:
            prompt (`str` or `List[str]`): Prompt(s) to encode; only used when `text_model_output` is `None`.
            device (`torch.device`): Device the outputs are placed on.
            num_images_per_prompt (`int`): Number of times each batch entry is repeated along dim 0.
            do_classifier_free_guidance (`bool`): If `True`, embeddings of the empty prompt are prepended along
                dim 0 (unconditional entries first, then conditional ones).
            text_model_output (`CLIPTextModelOutput` or `tuple`, *optional*): Pre-computed encoder outputs.
            text_attention_mask (`torch.Tensor`, *optional*): Attention mask matching `text_model_output`;
                required whenever `text_model_output` is passed.

        Returns:
            `tuple`: `(prompt_embeds, text_enc_hid_states, text_mask)`.

        Raises:
            `ValueError`: If `text_model_output` is passed without `text_attention_mask`.
        """
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            # Tokenize again without truncation purely to detect and report what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(text_input_ids.to(device))

            prompt_embeds = text_encoder_output.text_embeds
            text_enc_hid_states = text_encoder_output.last_hidden_state

        else:
            if text_attention_mask is None:
                # Without the mask the repeat below would fail with an opaque AttributeError on `None`.
                raise ValueError("`text_attention_mask` has to be passed together with `text_model_output`.")
            batch_size = text_model_output[0].shape[0]
            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_enc_hid_states.shape[1]
            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_enc_hid_states, text_mask

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        prior_num_inference_steps: int = 25,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prior_latents: Optional[torch.FloatTensor] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        prior_guidance_scale: float = 4.0,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output`
                and `text_attention_mask` are passed; when `text_model_output` is passed, `prompt` is ignored.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
                Pre-generated noisy latents to be used as inputs for the prior.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the super resolution stage.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            text_model_output (`CLIPTextModelOutput`, *optional*):
                Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text
                outputs can be passed for tasks like text embedding interpolations. Make sure to also pass
                `text_attention_mask` in this case. `prompt` can then be left `None`.
            text_attention_mask (`torch.Tensor`, *optional*):
                Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
                masks are necessary when passing `text_model_output`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Raises:
            `ValueError`: If `prompt` has an unsupported type, if neither `prompt` nor `text_model_output` is
                passed, or if `text_model_output` is passed without `text_attention_mask`.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        if prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if text_model_output is not None:
            # `_encode_prompt` ignores `prompt` when pre-computed outputs are given, so the batch size must be
            # derived from those outputs; otherwise the latents and embeddings can disagree in batch size.
            batch_size = text_model_output[0].shape[0]
        elif prompt is None:
            raise ValueError(
                "Either `prompt` or `text_model_output` (together with `text_attention_mask`) has to be passed."
            )
        else:
            batch_size = 1 if isinstance(prompt, str) else len(prompt)

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        # A single flag drives guidance for both the prior and the decoder.
        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

        prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
        )

        # prior

        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            prior_latents,
            self.prior_scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_enc_hid_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                # `_encode_prompt` places unconditional entries first, so chunk(2) yields (uncond, text).
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            # The final step has no successor timestep; the scheduler receives `None` for it.
            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeddings = prior_latents

        # done prior

        # decoder

        text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_enc_hid_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        # `text_proj` prepends `clip_extra_context_tokens` tokens to the hidden states, so the mask is
        # left-padded with "attend" values to keep it aligned.
        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.config.in_channels
        height = self.decoder.config.sample_size
        width = self.decoder.config.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_enc_hid_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_enc_hid_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                # The decoder output has more channels than its input: guidance is applied to the first
                # `in_channels` channels only, and the text branch's remaining channels (predicted variance)
                # are re-attached unchanged.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        # The super-res UNet input is [noisy latents, upscaled image] concatenated on dim 1,
        # so the latents take half of its input channels.
        channels = self.super_res_first.config.in_channels // 2
        height = self.super_res_first.config.sample_size
        width = self.super_res_first.config.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            # `antialias` is only passed on torch versions whose `F.interpolate` accepts it.
            interpolate_antialias = {}
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            # The last step uses a dedicated UNet; all earlier steps share `super_res_first`.
            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        self.maybe_free_model_hooks()

        # post processing: map from [-1, 1] to [0, 1], then to channels-last float32 numpy
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)