pipeline_stable_diffusion.py 35.3 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Suraj Patil's avatar
Suraj Patil committed
15
import inspect
16
from typing import Any, Callable, Dict, List, Optional, Union
Suraj Patil's avatar
Suraj Patil committed
17
18

import torch
19
from packaging import version
Suraj Patil's avatar
Suraj Patil committed
20
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
Suraj Patil's avatar
Suraj Patil committed
21

22
from ...configuration_utils import FrozenDict
Suraj Patil's avatar
Suraj Patil committed
23
from ...models import AutoencoderKL, UNet2DConditionModel
Kashif Rasul's avatar
Kashif Rasul committed
24
from ...schedulers import KarrasDiffusionSchedulers
25
26
27
28
29
30
31
32
from ...utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
    replace_example_docstring,
)
33
from ..pipeline_utils import DiffusionPipeline
34
from . import StableDiffusionPipelineOutput
Suraj Patil's avatar
Suraj Patil committed
35
from .safety_checker import StableDiffusionSafetyChecker
Suraj Patil's avatar
Suraj Patil committed
36
37


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example handed to `replace_example_docstring`, which decorates `StableDiffusionPipeline.__call__`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""


class StableDiffusionPipeline(DiffusionPipeline):
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
74
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
75
76
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
77
            Classification module that estimates whether generated images could be considered offensive or harmful.
apolinario's avatar
apolinario committed
78
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
79
80
81
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
82
    _optional_components = ["safety_checker", "feature_extractor"]
83

Suraj Patil's avatar
Suraj Patil committed
84
85
86
87
88
89
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        r"""
        Register the pipeline components, patching outdated scheduler/unet configs along the way.

        These configs trigger a deprecation message and are rewritten in place on the passed-in component (its
        `_internal_dict` is replaced):
        - a scheduler with `steps_offset != 1`;
        - a scheduler with `clip_sample=True`;
        - a unet saved with diffusers < 0.9.0 that has `sample_size < 64`.

        Raises:
            ValueError: if a `safety_checker` is given without a `feature_extractor`.
        """
        super().__init__()

        # Force `steps_offset` to 1 when the scheduler config says otherwise.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Force `clip_sample` to False when the scheduler config enables it.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # The safety checker consumes features produced by the feature extractor, so it cannot run without one.
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Unets saved by diffusers < 0.9.0 with a sample_size below 64 get their sample_size bumped to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Spatial down-sampling factor between image pixels and VAE latents (one halving per extra VAE block).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Turn on sliced VAE decoding.

        With slicing on, the VAE splits its input tensor into slices and decodes them over several steps, which saves
        some memory and allows larger batch sizes.
        """
        vae = self.vae
        vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding.

        Reverses `enable_vae_slicing`: the VAE goes back to computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Turn on tiled VAE processing.

        Once enabled, the VAE splits its input tensor into tiles and runs decoding and encoding over them in several
        steps, which saves a large amount of memory and makes it possible to process larger images.
        """
        vae = self.vae
        vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding.

        Reverses `enable_vae_tiling`: the VAE goes back to computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.

        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device (`cuda:{gpu_id}`) used as the execution device.

        Raises:
            ImportError: if `accelerate` is not available or is older than 0.14.0.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        # Bring everything back to CPU first so the offloading starts from a clean state.
        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        # The safety checker is optional; when present its buffers are offloaded too.
        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device (`cuda:{gpu_id}`) models are moved to while they run.

        Raises:
            ImportError: if `accelerate` is not available or does not satisfy `>= 0.17.0.dev0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Each hook receives the previous one as `prev_module_hook`, chaining the models in the order
        # text_encoder -> unet -> vae (-> safety_checker when present).
        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
265
        if not hasattr(self.unet, "_hf_hook"):
Anton Lozhkov's avatar
Anton Lozhkov committed
266
267
268
269
270
271
272
273
274
275
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

276
277
278
279
280
281
282
283
284
285
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
286
287
288
289
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
290
             prompt (`str` or `List[str]`, *optional*):
291
292
293
294
295
296
297
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
298
            negative_prompt (`str` or `List[str]`, *optional*):
299
300
301
302
303
304
305
306
307
308
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
309
        """
310
311
312
313
314
315
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
316

317
318
319
320
321
322
323
        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
324
            )
325
326
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
Patrick von Platen's avatar
Patrick von Platen committed
327

328
329
330
331
332
333
334
335
336
337
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
Patrick von Platen's avatar
Patrick von Platen committed
338

339
340
341
342
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None
343

344
345
346
347
348
349
350
351
352
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
353
        # duplicate text embeddings for each generation per prompt, using mps friendly method
354
355
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
356
357

        # get unconditional embeddings for classifier free guidance
358
        if do_classifier_free_guidance and negative_prompt_embeds is None:
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

378
            max_length = prompt_embeds.shape[1]
379
380
381
382
383
384
385
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
Patrick von Platen's avatar
Patrick von Platen committed
386
387
388
389
390
391

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

392
            negative_prompt_embeds = self.text_encoder(
Patrick von Platen's avatar
Patrick von Platen committed
393
394
395
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
396
            negative_prompt_embeds = negative_prompt_embeds[0]
397

398
        if do_classifier_free_guidance:
399
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
400
401
402
403
404
405
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
406
407
408
409

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
410
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
411

412
        return prompt_embeds
413

414
415
416
417
418
419
420
421
422
423
424
    def run_safety_checker(self, image, device, dtype):
        r"""
        Pass decoded images through the optional safety checker.

        Args:
            image: decoded images; converted with `self.numpy_to_pil` before feature extraction.
            device: device the extracted features are moved to.
            dtype: dtype the extracted `pixel_values` are cast to before being fed to the checker.

        Returns:
            `tuple`: `(image, has_nsfw_concept)`. These are the images and flags returned by the safety checker, or
            the untouched `image` and `None` when no safety checker is configured.
        """
        if self.safety_checker is None:
            return image, None

        features = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
        checked_image, nsfw_flags = self.safety_checker(images=image, clip_input=features.pixel_values.to(dtype))
        return checked_image, nsfw_flags

    def decode_latents(self, latents):
        r"""
        Decode latents with the VAE into a float32 NumPy image batch with values in [0, 1].

        Args:
            latents (`torch.FloatTensor`):
                Latents to decode. They are divided by `self.vae.config.scaling_factor` first. Assumed to be 4-D
                (NCHW); TODO confirm against callers.

        Returns:
            `np.ndarray`: the decoded images with axis 1 moved last (presumably NHWC), as float32.
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        # x / 2 + 0.5 maps [-1, 1] onto [0, 1]; clamp removes anything outside that range.
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        r"""
        Build the optional keyword arguments for `self.scheduler.step`.

        Not all schedulers share the same `step` signature, so `eta` and `generator` are only included when the
        scheduler's `step` declares a parameter with that name (in that order). `eta` corresponds to η in the DDIM
        paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1]. It is only used with the DDIMScheduler
        and is ignored by other schedulers.
        """
        # Inspect the signature once instead of once per optional argument.
        step_params = inspect.signature(self.scheduler.step).parameters
        optional_kwargs = {"eta": eta, "generator": generator}
        return {name: value for name, value in optional_kwargs.items() if name in step_params}

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        r"""
        Validate pipeline call arguments.

        Raises:
            ValueError: in any of these cases:
                - `height` or `width` is not divisible by 8;
                - `callback_steps` is not a positive `int`;
                - both or neither of `prompt` / `prompt_embeds` are given;
                - `prompt` is neither a `str` nor a `list`;
                - both `negative_prompt` and `negative_prompt_embeds` are given;
                - `prompt_embeds` and `negative_prompt_embeds` differ in shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` fails the isinstance check as well, so it needs no separate test.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # Directly passed embeddings are concatenated later, so their shapes must agree.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        r"""
        Create (or move) the initial latents and scale them by the scheduler's `init_noise_sigma`.

        Args:
            batch_size (`int`): number of latents to create; must equal `len(generator)` when a list is passed.
            num_channels_latents (`int`): channel count of the latents.
            height (`int`), width (`int`): image size in pixels; divided by `self.vae_scale_factor`.
            dtype, device: dtype/device for freshly sampled latents.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*): forwarded to `randn_tensor`.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated latents. They are only moved to `device`; their shape and dtype are not checked.

        Returns:
            `torch.FloatTensor`: the latents multiplied by `self.scheduler.init_noise_sigma`. Freshly sampled
            latents have shape `(batch_size, num_channels_latents, height // vae_scale_factor,
            width // vae_scale_factor)`.

        Raises:
            ValueError: if `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
514
    @replace_example_docstring(EXAMPLE_DOC_STRING)
Suraj Patil's avatar
Suraj Patil committed
515
516
    def __call__(
        self,
517
        prompt: Union[str, List[str]] = None,
518
519
        height: Optional[int] = None,
        width: Optional[int] = None,
520
521
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
522
        negative_prompt: Optional[Union[str, List[str]]] = None,
523
        num_images_per_prompt: Optional[int] = 1,
524
        eta: float = 0.0,
525
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
526
        latents: Optional[torch.FloatTensor] = None,
527
528
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
Suraj Patil's avatar
Suraj Patil committed
529
        output_type: Optional[str] = "pil",
530
        return_dict: bool = True,
531
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
532
        callback_steps: int = 1,
533
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
Suraj Patil's avatar
Suraj Patil committed
534
    ):
535
536
537
538
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
539
540
541
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
Patrick von Platen's avatar
Patrick von Platen committed
542
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
543
                The height in pixels of the generated image.
Patrick von Platen's avatar
Patrick von Platen committed
544
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
545
546
547
548
549
550
551
552
553
554
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
555
            negative_prompt (`str` or `List[str]`, *optional*):
556
557
558
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
559
560
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
561
562
563
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (畏) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
564
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
565
566
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
567
568
569
570
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
571
572
573
574
575
576
577
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
578
579
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
580
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
581
582
583
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
584
585
586
587
588
589
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
590
            cross_attention_kwargs (`dict`, *optional*):
Patrick von Platen's avatar
Patrick von Platen committed
591
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
592
593
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
594

595
596
        Examples:

597
        Returns:
598
599
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
600
601
602
603
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
604
        # 0. Default height and width to unet
Patrick von Platen's avatar
Patrick von Platen committed
605
606
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
Suraj Patil's avatar
Suraj Patil committed
607

608
        # 1. Check inputs. Raise error if not correct
609
610
611
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )
612

613
        # 2. Define call parameters
614
615
616
617
618
619
620
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

Anton Lozhkov's avatar
Anton Lozhkov committed
621
        device = self._execution_device
Suraj Patil's avatar
Suraj Patil committed
622
623
624
625
626
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

627
        # 3. Encode input prompt
628
629
630
631
632
633
634
635
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
636
        )
637

638
        # 4. Prepare timesteps
Anton Lozhkov's avatar
Anton Lozhkov committed
639
        self.scheduler.set_timesteps(num_inference_steps, device=device)
640
641
642
643
644
645
646
647
648
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
649
            prompt_embeds.dtype,
650
651
652
653
            device,
            generator,
            latents,
        )
Suraj Patil's avatar
Suraj Patil committed
654

655
656
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
hlky's avatar
hlky committed
657

658
        # 7. Denoising loop
659
660
661
662
663
664
665
666
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
667
668
669
670
671
672
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample
673
674
675
676
677
678

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

679
680
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
681
682

                # call the callback, if provided
683
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
684
685
686
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)
687

688
689
690
691
692
693
        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)
Suraj Patil's avatar
Suraj Patil committed
694

695
696
            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
Suraj Patil's avatar
Suraj Patil committed
697

698
            # 10. Convert to PIL
Suraj Patil's avatar
Suraj Patil committed
699
            image = self.numpy_to_pil(image)
700
701
702
703
704
705
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
Suraj Patil's avatar
Suraj Patil committed
706

707
708
709
710
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

711
712
713
714
        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)