# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer

from diffusers.utils import is_accelerate_available

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AltDiffusionPipeline

        >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap"
        >>> prompt = "榛戞殫绮剧伒鍏富锛岄潪甯歌缁嗭紝骞绘兂锛岄潪甯歌缁嗭紝鏁板瓧缁樼敾锛屾蹇佃壓鏈紝鏁忛攼鐨勭劍鐐癸紝鎻掑浘"
        >>> image = pipe(prompt).images[0]
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Alt Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`RobertaSeriesModelWithTransformation`]):
            Frozen text-encoder. Unlike Stable Diffusion, Alt Diffusion uses a multilingual text encoder based on
            [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta) rather than CLIP's text
            model, which is what allows the pipeline to handle non-English prompts.
        tokenizer (`XLMRobertaTokenizer`):
            Tokenizer of class
            [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: RobertaSeriesModelWithTransformation,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
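        # NOTE: with the standard KL VAE config, `block_out_channels` has 4 entries, so the
        # factor computed below is 2 ** 3 = 8 (a 512x512 image maps to 64x64 latents).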
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()
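
    # A minimal usage sketch for sliced VAE decoding (batch size and prompt are
    # illustrative, not part of the pipeline API):
    #
    #     >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9")
    #     >>> pipe.enable_vae_slicing()  # decode the batch one image at a time
    #     >>> images = pipe(["a cat"] * 8).images  # larger batches now fit in memory
    #     >>> pipe.disable_vae_slicing()  # restore single-pass decoding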

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
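
    # A minimal usage sketch (requires `accelerate`; `gpu_id=0` is the default):
    #
    #     >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9")
    #     >>> pipe.enable_sequential_cpu_offload()
    #     >>> # each submodule is moved to cuda:0 only for its own forward pass
    #     >>> image = pipe("a cat").images[0]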

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device
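
    # e.g. after `enable_sequential_cpu_offload()`, `self.device` reports "meta", while the
    # accelerate hook installed on the unet records the real execution device (such as
    # cuda:0), which is what this property returns.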

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
Patrick von Platen's avatar
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
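
    # Shape sketch (illustrative values): with batch_size=1, num_images_per_prompt=2 and
    # classifier-free guidance enabled, the returned `prompt_embeds` has shape
    # [2 * 1 * 2, seq_len, hidden_dim]; the first half holds the unconditional embeddings,
    # the second half the text-conditioned ones.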

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
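
    # The `1 / scaling_factor` multiplication above undoes the scaling applied when images
    # are encoded to latents; for the KL VAE used by Stable/Alt Diffusion checkpoints,
    # `vae.config.scaling_factor` is 0.18215.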

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
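
    # e.g. `DDIMScheduler.step` accepts both `eta` and `generator`, while
    # `PNDMScheduler.step` accepts neither, in which case the dict built above stays empty.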

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
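
    # Shape sketch: with the default unet (`in_channels=4`) and a 512x512 target,
    # `prepare_latents` returns a [batch, 4, 64, 64] tensor scaled by
    # `scheduler.init_noise_sigma` (1.0 for DDIM/PNDM-style schedulers, larger for
    # Karras-style sigma schedules).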

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
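                    # equivalently (1 - guidance_scale) * uncond + guidance_scale * text:
                    # an extrapolation past the text-conditioned prediction whenever
                    # guidance_scale > 1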

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        if not return_dict:
            return (image, has_nsfw_concept)

        return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)