# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Suraj Patil's avatar
Suraj Patil committed
15
import inspect
16
from typing import Any, Callable, Dict, List, Optional, Union
Suraj Patil's avatar
Suraj Patil committed
17
18

import torch
19
from packaging import version
20
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
Suraj Patil's avatar
Suraj Patil committed
21

22
from ...configuration_utils import FrozenDict
23
24
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
Suraj Patil's avatar
Suraj Patil committed
25
from ...models import AutoencoderKL, UNet2DConditionModel
26
from ...models.lora import adjust_lora_scale_text_encoder
Kashif Rasul's avatar
Kashif Rasul committed
27
from ...schedulers import KarrasDiffusionSchedulers
28
29
30
31
32
33
34
35
from ...utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
Dhruv Nair's avatar
Dhruv Nair committed
36
from ...utils.torch_utils import randn_tensor
37
from ..pipeline_utils import DiffusionPipeline
38
from .pipeline_output import StableDiffusionPipelineOutput
Suraj Patil's avatar
Suraj Patil committed
39
from .safety_checker import StableDiffusionSafetyChecker
Suraj Patil's avatar
Suraj Patil committed
40
41


42
43
# Module-level logger obtained from the diffusers `utils.logging` helpers (imported above as `logging`).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Usage example for this pipeline. `replace_example_docstring` is imported above; presumably it splices this
# text into the pipeline's `__call__` docstring (that usage is outside this view — confirm).
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""

58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is first scaled, sample by sample, so that its standard deviation (taken over every
    dimension except dim 0) matches that of `noise_pred_text`. This counters over-exposure. The rescaled tensor is
    then blended back with the un-rescaled one, which avoids "plain looking" images.

    Args:
        noise_cfg (`torch.Tensor`): Classifier-free-guided noise prediction.
        noise_pred_text (`torch.Tensor`): Text-conditioned noise prediction whose per-sample std is the target.
        guidance_rescale (`float`, defaults to 0.0): Blend weight of the rescaled prediction. 0.0 returns a value
            equal to `noise_cfg`; 1.0 returns the fully rescaled prediction.

    Returns:
        `torch.Tensor`: `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """

    def _per_sample_std(tensor):
        # std over every axis but the leading (batch) one, kept broadcastable against `tensor`
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    rescaled = noise_cfg * (_per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg))
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


73
74
75
class StableDiffusionPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
76
77
78
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

79
80
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
81

82
83
84
85
86
    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
87
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
1lint's avatar
1lint committed
88

89
90
    Args:
        vae ([`AutoencoderKL`]):
91
92
93
94
95
96
97
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
98
        scheduler ([`SchedulerMixin`]):
99
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
100
101
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
102
            Classification module that estimates whether generated images could be considered offensive or harmful.
103
104
105
106
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
107
    """
108

109
    model_cpu_offload_seq = "text_encoder->unet->vae"
110
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
111
    _exclude_from_cpu_offload = ["safety_checker"]
112
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
113

Suraj Patil's avatar
Suraj Patil committed
114
115
116
117
118
119
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
        requires_safety_checker: bool = True,
    ):
        """
        Patch outdated scheduler/unet configs, validate the safety-checker setup and register all components.

        Args:
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor:
                Pipeline components, registered as-is via `register_modules`.
            image_encoder (`CLIPVisionModelWithProjection`, *optional*):
                Vision encoder used by `encode_image`; may be `None`.
            requires_safety_checker (`bool`, defaults to `True`):
                Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.

        Raises:
            ValueError: if a `safety_checker` is given without a `feature_extractor`.
        """
        super().__init__()

        # The config is flagged as outdated when `steps_offset != 1`: warn, then patch the frozen config in place.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Same treatment for `clip_sample=True`, which the message says should be False.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # The safety checker consumes features produced by the feature extractor, so it cannot run without one.
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # UNet configs saved by diffusers < 0.9.0 with `sample_size < 64` are patched to 64 (see message below).
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        # Spatial downsampling factor between pixel space and latent space: one factor of 2 per VAE block after the first.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
Suraj Patil's avatar
Suraj Patil committed
205

206
207
    def enable_vae_slicing(self):
        r"""
        Turn on sliced VAE decoding by delegating to `self.vae.enable_slicing()`.

        With slicing on, the VAE splits its input tensor into slices and decodes them in several steps. This saves
        some memory and allows larger batch sizes.
        """
        vae = self.vae
        vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Turn off sliced VAE decoding by delegating to `self.vae.disable_slicing()`.

        If `enable_vae_slicing` was called earlier, decoding goes back to being computed in one step.
        """
        vae = self.vae
        vae.disable_slicing()

220
221
    def enable_vae_tiling(self):
        r"""
        Turn on tiled VAE decoding by delegating to `self.vae.enable_tiling()`.

        With tiling on, the VAE splits the input tensor into tiles and computes decoding and encoding in several
        steps. This saves a large amount of memory and allows processing larger images.
        """
        vae = self.vae
        vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Turn off tiled VAE decoding by delegating to `self.vae.disable_tiling()`.

        If `enable_vae_tiling` was called earlier, decoding goes back to being computed in one step.
        """
        vae = self.vae
        vae.disable_tiling()

235
236
237
238
239
240
241
242
243
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        """
        Deprecated wrapper around `encode_prompt`, kept for backwards compatibility.

        Emits a deprecation warning and forwards every argument, including extra `**kwargs`, to `encode_prompt`.
        Returns the legacy single-tensor layout: the second element of `encode_prompt`'s result (negative
        embeddings) concatenated before the first (positive embeddings) along dim 0.
        """
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        encoded = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )
        positive, negative = encoded[0], encoded[1]

        # legacy output: [negative; positive] stacked along the batch dimension
        return torch.cat([negative, positive])

    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`). A single string is applied to every prompt in the batch.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `tuple`: `(prompt_embeds, negative_prompt_embeds)`. `prompt_embeds` is repeated `num_images_per_prompt`
            times along the batch dimension. When classifier-free guidance is enabled, `negative_prompt_embeds` is
            repeated the same way. Otherwise it is returned as passed in (`None` by default).

        Raises:
            TypeError: if `negative_prompt` is not of the same type as `prompt`.
            ValueError: if a list `negative_prompt` does not match the batch size of `prompt`.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale; there is nothing to scale when the text encoder is absent
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Tokenize a second time without truncation so the tokens that were cut off can be reported to the user.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Prefer the text encoder's dtype, then the unet's, falling back to the embeddings' own dtype.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # Broadcast one negative prompt to the whole batch. This keeps its embeddings aligned with
                # `prompt_embeds` when batched `prompt_embeds` were passed instead of `prompt`. Otherwise the
                # `.view(batch_size * num_images_per_prompt, ...)` below would fail or mis-shape the result.
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            # NOTE(review): `clip_skip` is not applied to the negative embeddings (last layer output is used) —
            # confirm this asymmetry is intended.
            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
447

448
449
450
451
452
453
454
455
456
457
458
459
460
    def encode_image(self, image, device, num_images_per_prompt):
        """
        Embed `image` with `self.image_encoder` and pair it with an all-zeros "unconditional" embedding.

        Inputs that are not already a `torch.Tensor` are first turned into pixel values by `self.feature_extractor`.
        Pixels are moved to `device` and cast to the dtype of the image encoder's first parameter. The resulting
        embeddings are repeated `num_images_per_prompt` times with `repeat_interleave` along dim 0.

        Returns:
            `(image_embeds, uncond_image_embeds)`, where the second is `torch.zeros_like` the first.
        """
        encoder_dtype = next(self.image_encoder.parameters()).dtype

        if isinstance(image, torch.Tensor):
            pixels = image
        else:
            pixels = self.feature_extractor(image, return_tensors="pt").pixel_values
        pixels = pixels.to(device=device, dtype=encoder_dtype)

        embeds = self.image_encoder(pixels).image_embeds
        embeds = embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return embeds, torch.zeros_like(embeds)

461
    def run_safety_checker(self, image, device, dtype):
        """
        Run the optional safety checker over decoded images.

        Without a safety checker, `image` is returned untouched together with `None`. Otherwise `image` is converted
        to PIL for the feature extractor: tensors go through `image_processor.postprocess`, anything else through
        `image_processor.numpy_to_pil`. The extracted pixel values are cast to `dtype` and passed to
        `self.safety_checker`, whose two outputs are returned.

        Returns:
            `(image, has_nsfw_concept)`
        """
        if self.safety_checker is None:
            return image, None

        if torch.is_tensor(image):
            pil_images = self.image_processor.postprocess(image, output_type="pil")
        else:
            pil_images = self.image_processor.numpy_to_pil(image)
        clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)

        checked_image, nsfw_flags = self.safety_checker(
            images=image, clip_input=clip_features.pixel_values.to(dtype)
        )
        return checked_image, nsfw_flags

    def decode_latents(self, latents):
        """
        Deprecated: decode `latents` with the VAE into a float32 NumPy image batch clamped to [0, 1].

        Emits a deprecation warning pointing to `VaeImageProcessor.postprocess(...)`. The latents are divided by
        `vae.config.scaling_factor` and decoded. Values are mapped with `x / 2 + 0.5` and clamped to [0, 1], then
        moved to CPU with axes permuted `(0, 2, 3, 1)`. That is channels-last if the decoder output is NCHW
        (assumed — TODO confirm).
        """
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        # undo the latent scaling before handing the tensor to the decoder
        unscaled = 1 / self.vae.config.scaling_factor * latents
        decoded = self.vae.decode(unscaled, return_dict=False)[0]
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        return decoded.cpu().permute(0, 2, 3, 1).float().numpy()

    def prepare_extra_step_kwargs(self, generator, eta):
        """
        Build the optional keyword arguments for `self.scheduler.step`.

        Not all schedulers share the same `step` signature. `eta` and `generator` are therefore included only when
        `step` declares a parameter with that name; the result keeps the order `eta`, then `generator`.
        eta (η) is only used with the DDIMScheduler and is ignored by other schedulers. It corresponds to η in the
        DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
        """
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
        candidates = {"eta": eta, "generator": generator}
        return {name: value for name, value in candidates.items() if name in step_params}

503
504
505
506
507
508
509
510
511
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate the call arguments, raising `ValueError` on the first problem found.

        Checks, in this order:
            - `height` and `width` are both divisible by 8;
            - `callback_steps` is `None` or a positive `int`;
            - every entry of `callback_on_step_end_tensor_inputs` is listed in `self._callback_tensor_inputs`;
            - exactly one of `prompt` / `prompt_embeds` is given, and `prompt` is a `str` or `list`;
            - at most one of `negative_prompt` / `negative_prompt_embeds` is given;
            - directly passed `prompt_embeds` and `negative_prompt_embeds` have the same shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        callback_steps_ok = callback_steps is None or (isinstance(callback_steps, int) and callback_steps > 0)
        if not callback_steps_ok:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        allowed_tensor_inputs = self._callback_tensor_inputs
        if callback_on_step_end_tensor_inputs is not None and not all(
            name in allowed_tensor_inputs for name in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {allowed_tensor_inputs}, but found {[name for name in callback_on_step_end_tensor_inputs if name not in allowed_tensor_inputs]}"
            )

        has_prompt = prompt is not None
        has_prompt_embeds = prompt_embeds is not None
        if has_prompt and has_prompt_embeds:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if not has_prompt and not has_prompt_embeds:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if has_prompt and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        both_embeds_given = has_prompt_embeds and negative_prompt_embeds is not None
        if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

555
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Create (or adopt) the initial latents and scale them by `self.scheduler.init_noise_sigma`.

        The latent shape is `(batch_size, num_channels_latents, height // vae_scale_factor,
        width // vae_scale_factor)`. When `latents` is `None`, noise of that shape is drawn with `randn_tensor`
        using `generator`, `device` and `dtype`. Otherwise the given tensor is only moved to `device`; its shape and
        dtype are not checked or changed here.

        Raises:
            ValueError: if `generator` is a list whose length differs from `batch_size`.
        """
        latent_shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
        else:
            initial = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return initial * self.scheduler.init_noise_sigma

572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""Turn on the FreeU mechanism (https://arxiv.org/abs/2309.11497) on this pipeline's UNet.

        The numeric suffix of each factor names the stage it is applied to. Value combinations known to work well
        for Stable Diffusion v1, v2 and Stable Diffusion XL are listed in the
        [official repository](https://github.com/ChenyangSi/FreeU).

        Args:
            s1 (`float`):
                Stage-1 skip-feature scaling factor. Attenuates the skip features to mitigate the "oversmoothing
                effect" in the enhanced denoising process.
            s2 (`float`):
                Stage-2 skip-feature scaling factor; same role as `s1`, applied at stage 2.
            b1 (`float`): Stage-1 backbone scaling factor, amplifying the backbone features.
            b2 (`float`): Stage-2 backbone scaling factor, amplifying the backbone features.

        Raises:
            ValueError: if the pipeline has no `unet` attribute.
        """
        if hasattr(self, "unet"):
            freeu_factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
            self.unet.enable_freeu(**freeu_factors)
            return
        raise ValueError("The pipeline must have `unet` for using FreeU.")

    def disable_freeu(self):
        """Turn FreeU off by delegating to `self.unet.disable_freeu()`."""
        unet = self.unet
        unet.disable_freeu()

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        Build sinusoidal embeddings of guidance-scale values.

        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                1-D tensor of guidance-scale values; one embedding vector is generated per entry.
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate (odd sizes are zero-padded by one column)
            dtype:
                data type of the generated embeddings

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`.
        """
        # NOTE(review): this method is marked "Copied from" the LCM pipeline; keep that copy in sync with this fix.
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        # Build the frequency table on `w`'s device: a CPU-only table would clash with a GPU `w` below.
        emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    @property
    def guidance_scale(self):
        """Guidance weight stored by the most recent `__call__` (its `guidance_scale` argument)."""
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        """Guidance-rescale factor stored by the most recent `__call__` (its `guidance_rescale` argument)."""
        return self._guidance_rescale

    @property
    def clip_skip(self):
        """Number of CLIP layers to skip, as stored by the most recent `__call__` (its `clip_skip` argument)."""
        return self._clip_skip

    @property
    def do_classifier_free_guidance(self):
        """
        Whether classifier-free guidance is active for the current call.

        `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) of the Imagen paper
        (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` corresponds to no classifier-free guidance.
        The property is also `False` when the UNet config sets `time_cond_proj_dim`; for such UNets `__call__`
        passes the guidance scale through `timestep_cond` instead.
        """
        if self._guidance_scale > 1:
            return self.unet.config.time_cond_proj_dim is None
        return False

    @property
    def cross_attention_kwargs(self):
        """Attention-processor kwargs stored by the most recent `__call__` (may be `None`)."""
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        """Length of the scheduler timestep sequence used by the most recent `__call__`."""
        return self._num_timesteps

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[Any, int, int, Dict], Dict]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
                If the UNet config sets `time_cond_proj_dim`, classifier-free guidance is not run and the guidance
                scale is instead passed to the UNet as `timestep_cond`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. `"latent"`
                returns the denoised latents without VAE decoding or safety checking.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`. The function must return a dict; `latents`, `prompt_embeds` and
                `negative_prompt_embeds` entries in that dict replace the pipeline's current values.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            kwargs:
                Only the deprecated `callback` (called as `callback(step_idx, timestep, latents)`) and
                `callback_steps` (call `callback` every `callback_steps` steps; every step when omitted) are read.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    # Look the requested tensors up by their local variable names.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    # Legacy callback: without `callback_steps` it fires every step (the former default of 1)
                    # instead of failing on `i % None`.
                    if callback is not None and (callback_steps is None or i % callback_steps == 0):
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)