Unverified commit 3768d4d7, authored by Sayak Paul, committed by GitHub

[Core] refactor encode_prompt (#4617)



* refactoring of encode_prompt()

* better handling of device.

* fix: device determination

* fix: device determination 2

* handle num_images_per_prompt

* revert changes in loaders.py and give birth to encode_prompt().

* minor refactoring for encode_prompt().

* make backward compatible.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix: concatenation of the neg and pos embeddings.

* incorporate encode_prompt() in test_stable_diffusion.py

* turn it into a big PR.

* make it bigger

* gligen fixes.

* more fixes to gligen

* _encode_prompt -> encode_prompt in tests

* first batch

* second batch

* fix blasphemous mistake

* fix

* fix: hopefully for the final time.

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 8ccb6194
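This commit's central change, repeated across every pipeline below: the private `_encode_prompt()` used to return the negative and positive prompt embeddings pre-concatenated for classifier-free guidance; the new public `encode_prompt()` returns them as a `(prompt_embeds, negative_prompt_embeds)` tuple, and `_encode_prompt()` survives only as a deprecation shim. A minimal migration sketch for calling code (the checkpoint id is illustrative, not part of this diff):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # illustrative checkpoint

# Old (deprecated): one tensor with the unconditional embeddings stacked first.
# embeds = pipe._encode_prompt("a photo of a cat", "cpu", 1, True)

# New: a tuple; concatenate yourself only when doing classifier-free guidance.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of a cat",
    device="cpu",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
embeds = torch.cat([negative_prompt_embeds, prompt_embeds])  # reproduces the old output
```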
@@ -24,7 +24,7 @@ from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, randn_tensor, replace_example_docstring
+from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -171,6 +171,37 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -315,12 +346,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -549,7 +575,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         do_self_attention_guidance = sag_scale > 0.0
 
         # 3. Encode input prompt
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -558,6 +584,11 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
...
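The new `if do_classifier_free_guidance:` concatenation keeps the unconditional embeddings in the first half of the batch, which is what the denoising loop's `chunk(2)` split assumes. A hedged sketch of that consuming side (not part of this diff; names are illustrative, following the pattern used across the diffusers pipelines):

```python
import torch

def guided_noise_pred(unet, latents, t, prompt_embeds, guidance_scale):
    # Run unconditional and conditional passes as one batch: the embeddings
    # were concatenated as [negative, positive], so duplicate the latents too.
    latent_model_input = torch.cat([latents] * 2)
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

    # Split back in the same order the embeddings were concatenated.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```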
@@ -192,6 +192,37 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -336,12 +367,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -629,7 +655,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -639,6 +665,11 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
         )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Preprocess image
         image = self.image_processor.preprocess(image)
...
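Every `_encode_prompt()` shim in this commit follows the same three steps: warn through the `deprecate` utility, delegate to `encode_prompt()`, and rebuild the old concatenated tensor from the returned tuple. As a rough approximation only (not the actual diffusers implementation), a helper with the call signature used above could look like:

```python
import warnings

def deprecate(name: str, removal_version: str, message: str, standard_warn: bool = True):
    # Emit a FutureWarning naming the deprecated feature and when it goes away.
    # With standard_warn=False, callers supply the full message themselves,
    # as the pipelines above do for `_encode_prompt()`.
    prefix = (
        f"`{name}` is deprecated and will be removed in version {removal_version}. "
        if standard_warn
        else ""
    )
    warnings.warn(prefix + message, FutureWarning, stacklevel=2)
```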
@@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
+from ...utils import (
+    deprecate,
+    is_accelerate_available,
+    is_accelerate_version,
+    logging,
+    randn_tensor,
+    replace_example_docstring,
+)
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -295,6 +302,37 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -439,12 +477,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
@@ -819,7 +852,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -829,6 +862,11 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
         )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 9. Prepare image embeddings
         image_embeds = self.noise_image_embeddings(
...
@@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
+from ...utils import deprecate, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -196,6 +196,92 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    def _encode_image(
+        self,
+        image,
+        device,
+        batch_size,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        noise_level,
+        generator,
+        image_embeds,
+    ):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if isinstance(image, PIL.Image.Image):
+            # the image embedding should be repeated so it matches the total batch size of the prompt
+            repeat_by = batch_size
+        else:
+            # assume the image input is already properly batched and just needs to be repeated so
+            # it matches the num_images_per_prompt.
+            #
+            # NOTE(will) this is probably missing a few side cases, e.g. batched/non-batched
+            # `image_embeds`. If those happen to be common use cases, let's think harder about
+            # what the expected dimensions of inputs should be and how we handle the encoding.
+            repeat_by = num_images_per_prompt
+
+        if image_embeds is None:
+            if not isinstance(image, torch.Tensor):
+                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+            image = image.to(device=device, dtype=dtype)
+            image_embeds = self.image_encoder(image).image_embeds
+
+        image_embeds = self.noise_image_embeddings(
+            image_embeds=image_embeds,
+            noise_level=noise_level,
+            generator=generator,
+        )
+
+        # duplicate image embeddings for each generation per prompt, using mps friendly method
+        image_embeds = image_embeds.unsqueeze(1)
+        bs_embed, seq_len, _ = image_embeds.shape
+        image_embeds = image_embeds.repeat(1, repeat_by, 1)
+        image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
+        image_embeds = image_embeds.squeeze(1)
+
+        if do_classifier_free_guidance:
+            negative_prompt_embeds = torch.zeros_like(image_embeds)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -340,67 +426,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
-
-    def _encode_image(
-        self,
-        image,
-        device,
-        batch_size,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        noise_level,
-        generator,
-        image_embeds,
-    ):
-        dtype = next(self.image_encoder.parameters()).dtype
-
-        if isinstance(image, PIL.Image.Image):
-            # the image embedding should be repeated so it matches the total batch size of the prompt
-            repeat_by = batch_size
-        else:
-            # assume the image input is already properly batched and just needs to be repeated so
-            # it matches the num_images_per_prompt.
-            #
-            # NOTE(will) this is probably missing a few side cases, e.g. batched/non-batched
-            # `image_embeds`. If those happen to be common use cases, let's think harder about
-            # what the expected dimensions of inputs should be and how we handle the encoding.
-            repeat_by = num_images_per_prompt
-
-        if image_embeds is None:
-            if not isinstance(image, torch.Tensor):
-                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
-
-            image = image.to(device=device, dtype=dtype)
-            image_embeds = self.image_encoder(image).image_embeds
-
-        image_embeds = self.noise_image_embeddings(
-            image_embeds=image_embeds,
-            noise_level=noise_level,
-            generator=generator,
-        )
-
-        # duplicate image embeddings for each generation per prompt, using mps friendly method
-        image_embeds = image_embeds.unsqueeze(1)
-        bs_embed, seq_len, _ = image_embeds.shape
-        image_embeds = image_embeds.repeat(1, repeat_by, 1)
-        image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
-        image_embeds = image_embeds.squeeze(1)
-
-        if do_classifier_free_guidance:
-            negative_prompt_embeds = torch.zeros_like(image_embeds)
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
-
-        return image_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
@@ -718,7 +744,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
@@ -728,6 +754,11 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
         )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Encoder input image
         noise_level = torch.tensor([noise_level], device=device)
...
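`_encode_image()` (moved above `encode_prompt` in this file, otherwise unchanged) duplicates image embeddings per prompt with an unsqueeze/repeat/view dance rather than `repeat_interleave`, a pattern the code comment calls "mps friendly". A small sketch showing the shapes, with dimensions chosen only for illustration:

```python
import torch

image_embeds = torch.randn(2, 768)   # (batch, embed_dim), illustrative sizes
repeat_by = 3                        # e.g. num_images_per_prompt

x = image_embeds.unsqueeze(1)        # (2, 1, 768): add a fake sequence dim
x = x.repeat(1, repeat_by, 1)        # (2, 3, 768): copy along that dim
x = x.view(2 * repeat_by, 1, 768)    # (6, 1, 768): fold the copies into the batch
x = x.squeeze(1)                     # (6, 768)

# Same result as repeat_interleave along dim 0, which historically had
# spotty support on the mps backend:
assert torch.equal(x, image_embeds.repeat_interleave(repeat_by, dim=0))
```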
@@ -29,6 +29,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -255,6 +256,37 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -399,12 +431,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -692,7 +719,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -701,6 +728,11 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
...
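Note the guard at every new call site: `encode_prompt()` builds negative embeddings only when classifier-free guidance is on (the Adapter pipeline derives that from `guidance_scale > 1.0`); otherwise the second tuple element comes back `None`. A small helper sketch of the resulting calling pattern (`embeds_for_call` is a hypothetical name, not from this diff):

```python
import torch

def embeds_for_call(pipe, prompt: str, device: str, use_cfg: bool) -> torch.Tensor:
    # encode_prompt() only computes negative embeddings when classifier-free
    # guidance is enabled; with use_cfg=False the second element is None,
    # which is why the new call sites and tests guard the concatenation.
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, device, 1, use_cfg)
    if negative_prompt_embeds is not None:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    return prompt_embeds
```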
@@ -23,6 +23,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -182,6 +183,37 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -326,12 +358,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
@@ -561,7 +588,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -571,6 +598,11 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
         )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
...
@@ -24,6 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -244,6 +245,37 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -388,12 +420,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
     def decode_latents(self, latents):
@@ -640,7 +667,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds = self._encode_prompt(
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -650,6 +677,11 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
        )
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Preprocess video
         video = preprocess_video(video)
...
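One more thread running through the `__call__` hunks: the LoRA scale is not a separate argument but is plucked out of `cross_attention_kwargs` as `text_encoder_lora_scale` and handed to `encode_prompt()`, so a single `"scale"` value steers LoRA in both the text encoder and the UNet. A hedged end-to-end sketch (checkpoint id and LoRA path are placeholders, not from this diff):

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder ids/paths; any LoRA-enabled pipeline touched by this commit behaves the same.
pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/lora")  # hypothetical LoRA weights

# "scale" reaches the UNet attention processors directly and, via the
# text_encoder_lora_scale plumbing shown above, encode_prompt()'s lora_scale.
frames = pipe(prompt="an astronaut riding a horse", cross_attention_kwargs={"scale": 0.5}).frames
```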
@@ -584,21 +584,27 @@ class StableDiffusionPipelineFastTests(
         prompt = 25 * "@"
         with CaptureLogger(logger) as cap_logger_3:
-            text_embeddings_3 = sd_pipe._encode_prompt(
+            text_embeddings_3, negative_text_embeddings_3 = sd_pipe.encode_prompt(
                 prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
             )
+            if negative_text_embeddings_3 is not None:
+                text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3])
 
         prompt = 100 * "@"
         with CaptureLogger(logger) as cap_logger:
-            text_embeddings = sd_pipe._encode_prompt(
+            text_embeddings, negative_text_embeddings = sd_pipe.encode_prompt(
                 prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
             )
+            if negative_text_embeddings is not None:
+                text_embeddings = torch.cat([negative_text_embeddings, text_embeddings])
 
         negative_prompt = "Hello"
         with CaptureLogger(logger) as cap_logger_2:
-            text_embeddings_2 = sd_pipe._encode_prompt(
+            text_embeddings_2, negative_text_embeddings_2 = sd_pipe.encode_prompt(
                 prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
             )
+            if negative_text_embeddings_2 is not None:
+                text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2])
 
         assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
         assert text_embeddings.shape[1] == 77
...
@@ -245,21 +245,27 @@ class StableDiffusion2PipelineFastTests(
         prompt = 25 * "@"
         with CaptureLogger(logger) as cap_logger_3:
-            text_embeddings_3 = sd_pipe._encode_prompt(
+            text_embeddings_3, negative_text_embeddings_3 = sd_pipe.encode_prompt(
                 prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
             )
+            if negative_text_embeddings_3 is not None:
+                text_embeddings_3 = torch.cat([negative_text_embeddings_3, text_embeddings_3])
 
         prompt = 100 * "@"
         with CaptureLogger(logger) as cap_logger:
-            text_embeddings = sd_pipe._encode_prompt(
+            text_embeddings, negative_text_embeddings = sd_pipe.encode_prompt(
                 prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
             )
+            if negative_text_embeddings is not None:
+                text_embeddings = torch.cat([negative_text_embeddings, text_embeddings])
 
         negative_prompt = "Hello"
         with CaptureLogger(logger) as cap_logger_2:
-            text_embeddings_2 = sd_pipe._encode_prompt(
+            text_embeddings_2, negative_text_embeddings_2 = sd_pipe.encode_prompt(
                 prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
             )
+            if negative_text_embeddings_2 is not None:
+                text_embeddings_2 = torch.cat([negative_text_embeddings_2, text_embeddings_2])
 
         assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
         assert text_embeddings.shape[1] == 77
...
@@ -251,7 +251,10 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
         image = output.images
 
         generator = torch.Generator(device=device).manual_seed(0)
-        prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
+        prompt_embeds, negative_prompt_embeds = sd_pipe.encode_prompt(prompt, device, 1, False)
+        if negative_prompt_embeds is not None:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
         image_from_prompt_embeds = sd_pipe(
             prompt_embeds=prompt_embeds,
             image=[low_res_image],
...