Unverified commit 74fd735e, authored by Patrick von Platen, committed by GitHub

Add draft for lora text encoder scale (#3626)



* Add draft for lora text encoder scale

* Improve naming

* fix: training dreambooth lora script.

* Apply suggestions from code review

* Update examples/dreambooth/train_dreambooth_lora.py

* Apply suggestions from code review

* Apply suggestions from code review

* add lora mixin when fit

* add lora mixin when fit

* add lora mixin when fit

* fix more

* fix more

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 2de9e2df
@@ -260,6 +260,14 @@ pipe.load_lora_weights(lora_model_id)
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
+<Tip>
+If your LoRA parameters involve the UNet as well as the Text Encoder, then passing
+`cross_attention_kwargs={"scale": 0.5}` will apply the `scale` value to both the UNet
+and the Text Encoder.
+</Tip>
Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is preferred to [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
[`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] can handle the following situations:
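For illustration, here is a minimal usage sketch of the behaviour described in the tip above. The base checkpoint and the LoRA path are placeholders, not names taken from this PR:

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder model ID and LoRA weights, for illustration only.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # may contain UNet *and* text encoder LoRA layers

# "scale" is read from cross_attention_kwargs; with this change it is applied to the
# text encoder LoRA layers as well as to the UNet LoRA layers.
image = pipe(
    "A picture of a sks dog in a bucket",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```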
@@ -852,6 +852,9 @@ class LoraLoaderMixin:
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+# set lora scale to a reasonable default
+self._lora_scale = 1.0
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
@@ -953,6 +956,12 @@ class LoraLoaderMixin:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
+@property
+def lora_scale(self) -> float:
+# property function that returns the lora scale which can be set at run time by the pipeline.
+# if _lora_scale has not been set, return 1
+return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
@property
def text_encoder_lora_attn_procs(self):
if hasattr(self, "_text_encoder_lora_attn_procs"):
@@ -1000,7 +1009,8 @@ class LoraLoaderMixin:
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
-return old_forward(x) + lora_layer(x)
+result = old_forward(x) + self.lora_scale * lora_layer(x)
+return result
return new_forward
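The scaling itself happens in the monkey-patched text encoder forward shown in the last hunk: `new_forward` keeps the original output and adds the LoRA branch multiplied by `self.lora_scale`. Below is a self-contained sketch of the same pattern, assuming a plain `nn.Linear` as a stand-in for a text encoder projection and another `nn.Linear` as a stand-in for the LoRA module (both are illustrative, not the library's actual classes):

```python
import torch
import torch.nn as nn

linear = nn.Linear(768, 768)                  # stand-in for a text encoder projection
lora_layer = nn.Linear(768, 768, bias=False)  # stand-in for a LoRA module

lora_scale = 0.5  # plays the role of LoraLoaderMixin._lora_scale


def make_new_forward(old_forward, lora_layer):
    # same shape as the patch in the diff: wrap the original forward and add a scaled LoRA branch
    def new_forward(x):
        return old_forward(x) + lora_scale * lora_layer(x)

    return new_forward


# monkey patch the instance; nn.Module.__call__ picks up the instance attribute
linear.forward = make_new_forward(linear.forward, lora_layer)
out = linear(torch.randn(1, 768))  # output now includes the scaled LoRA contribution
```

With a scale of 0.0 the LoRA branch is effectively disabled, while 1.0 (the default set in `load_lora_weights` above) applies it at full strength.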
@@ -24,7 +24,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
@@ -52,7 +52,7 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Alt Diffusion.
@@ -291,6 +291,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -315,7 +316,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -653,6 +661,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -661,6 +672,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
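The remaining pipeline files repeat this same plumbing: `__call__` reads `scale` out of `cross_attention_kwargs` and forwards it to `_encode_prompt`, which stores it on the pipeline as `_lora_scale` so the patched text encoder layers can read it. A condensed, self-contained sketch of that flow (the class below is purely illustrative, not a real pipeline):

```python
class ScalePlumbingSketch:
    """Illustrative stand-in for the pattern added to each pipeline in this PR."""

    _lora_scale = 1.0  # default, analogous to the value set in load_lora_weights

    def _encode_prompt(self, prompt, lora_scale=None):
        # mirrors the lines added to each pipeline's _encode_prompt
        if lora_scale is not None:
            self._lora_scale = lora_scale
        return prompt  # a real pipeline returns text encoder hidden states here

    def __call__(self, prompt, cross_attention_kwargs=None):
        # mirrors "# 3. Encode input prompt" in each pipeline's __call__
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        return self._encode_prompt(prompt, lora_scale=text_encoder_lora_scale)


sketch = ScalePlumbingSketch()
sketch("a prompt", cross_attention_kwargs={"scale": 0.5})
assert sketch._lora_scale == 0.5
```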
@@ -26,7 +26,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -95,7 +95,7 @@ def preprocess(image):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Alt Diffusion.
@@ -302,6 +302,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -326,7 +327,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -706,6 +714,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -714,6 +725,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Preprocess image
@@ -25,7 +25,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -91,7 +91,7 @@ EXAMPLE_DOC_STRING = """
"""
-class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -291,6 +291,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -315,7 +316,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -838,6 +846,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -846,6 +857,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Prepare image
@@ -25,7 +25,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -117,7 +117,7 @@ def prepare_image(image):
return image
-class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -317,6 +317,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -341,7 +342,14 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -929,6 +937,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -937,6 +948,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Prepare image
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
@@ -26,7 +26,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -223,7 +223,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
return mask, masked_image
-class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -434,6 +434,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -458,7 +459,14 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -1131,6 +1139,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -1139,6 +1150,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Prepare image
@@ -14,7 +14,7 @@
import inspect
import warnings
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
@@ -26,7 +26,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
@@ -126,7 +126,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
return noise
-class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -315,6 +315,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -339,7 +340,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -629,6 +637,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
+cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -685,6 +694,10 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
+cross_attention_kwargs (`dict`, *optional*):
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+`self.processor` in
+[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -705,12 +718,16 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
source_prompt_embeds = self._encode_prompt(
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
@@ -764,7 +781,10 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
dim=0,
)
concat_noise_pred = self.unet(
-concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds
+concat_latent_model_input,
+t,
+cross_attention_kwargs=cross_attention_kwargs,
+encoder_hidden_states=concat_prompt_embeds,
).sample
# perform guidance
@@ -294,6 +294,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -318,7 +319,14 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -654,6 +662,9 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -662,6 +673,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
@@ -23,7 +23,7 @@ from torch.nn import functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import KarrasDiffusionSchedulers
@@ -306,6 +306,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -330,7 +331,14 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -15,7 +15,7 @@
import contextlib
import inspect
import warnings
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
@@ -183,6 +183,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -207,7 +208,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -546,6 +554,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
+cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -606,6 +615,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
+cross_attention_kwargs (`dict`, *optional*):
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+`self.processor` in
+[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Examples:
@@ -665,6 +678,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -673,6 +689,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Prepare depth mask
@@ -711,9 +728,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
# predict the noise residual
-noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
-0
-]
+noise_pred = self.unet(
+latent_model_input,
+t,
+encoder_hidden_states=prompt_embeds,
+cross_attention_kwargs=cross_attention_kwargs,
+return_dict=False,
+)[0]
# perform guidance
if do_classifier_free_guidance:
@@ -487,6 +487,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -511,7 +512,14 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -1007,6 +1015,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompts
+(cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
target_prompt_embeds = self._encode_prompt(
target_prompt,
device,
@@ -1458,6 +1467,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -1466,6 +1478,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Preprocess mask
@@ -309,6 +309,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -333,7 +334,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -714,6 +722,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -722,6 +733,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Preprocess image
@@ -378,6 +378,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -402,7 +403,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -898,6 +906,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -906,6 +917,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. set timesteps
@@ -14,7 +14,7 @@
import inspect
import warnings
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
@@ -304,6 +304,7 @@ class StableDiffusionInpaintPipelineLegacy(
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -328,7 +329,14 @@ class StableDiffusionInpaintPipelineLegacy(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -575,6 +583,7 @@ class StableDiffusionInpaintPipelineLegacy(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
+cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -639,6 +648,10 @@ class StableDiffusionInpaintPipelineLegacy(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
+cross_attention_kwargs (`dict`, *optional*):
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+`self.processor` in
+[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -665,6 +678,9 @@ class StableDiffusionInpaintPipelineLegacy(
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
+text_encoder_lora_scale = (
+cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -673,6 +689,7 @@ class StableDiffusionInpaintPipelineLegacy(
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+lora_scale=text_encoder_lora_scale,
)
# 4. Preprocess image and mask
@@ -708,9 +725,13 @@ class StableDiffusionInpaintPipelineLegacy(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
-noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
-0
-]
+noise_pred = self.unet(
+latent_model_input,
+t,
+encoder_hidden_states=prompt_embeds,
+cross_attention_kwargs=cross_attention_kwargs,
+return_dict=False,
+)[0]
# perform guidance
if do_classifier_free_guidance:
@@ -21,7 +21,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import get_sigmas_karras
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -210,6 +210,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -234,7 +235,14 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+lora_scale (`float`, *optional*):
+A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+# set lora scale so that monkey patched LoRA
+# function of text encoder can correctly access it
+if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):