"vscode:/vscode.git/clone" did not exist on "81bdbb5e2a60fdfd4ee3c25ee2619adcd60f96ca"
Unverified Commit f240a936 authored by Anatoly Belikov, committed by GitHub

handle lora scale and clip skip in lpw sd and sdxl community pipelines (#8988)



* handle lora scale and clip skip in lpw sd and sdxl

* use StableDiffusionLoraLoaderMixin

* use StableDiffusionXLLoraLoaderMixin

* style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 00d8d46e
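For context, a minimal usage sketch of what this commit enables in the SD 1.5 long-prompt-weighting pipeline: `clip_skip` is now accepted by the pipeline call, and the LoRA scale passed via `cross_attention_kwargs` is now forwarded to the text encoder. The checkpoint ID and LoRA path below are placeholders, not part of the commit:

```python
import torch
from diffusers import DiffusionPipeline

# Load the long-prompt-weighting community pipeline (placeholder checkpoint ID).
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Hypothetical LoRA weights; any SD 1.5 LoRA would do.
pipe.load_lora_weights("path/to/lora")

image = pipe(
    prompt="a (red:1.2) apple on a wooden table",
    num_inference_steps=30,
    clip_skip=2,  # new: 1 selects the penultimate CLIP layer, 2 one layer earlier
    cross_attention_kwargs={"scale": 0.8},  # new: scale now also reaches the text encoder's LoRA layers
).images[0]
```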
examples/community/lpw_stable_diffusion.py
@@ -13,13 +13,17 @@ from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
    text_input: torch.Tensor,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
    clip_skip: Optional[int] = None,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
@@ -214,7 +219,20 @@
            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]
            text_embedding = pipe.text_encoder(text_input_chunk)[0]
            if clip_skip is None:
                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
                text_embedding = prompt_embeds[0]
            else:
                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)

            if no_boseos_middle:
                if i == 0:
@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        text_embeddings = pipe.text_encoder(text_input)[0]
        if clip_skip is None:
            clip_skip = 0
        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
    return text_embeddings
@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip=None,
    lora_scale=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +291,16 @@
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
        pipe._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
        else:
            scale_lora_layers(pipe.text_encoder, lora_scale)

    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]
@@ -334,10 +367,7 @@
    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        no_boseos_middle=no_boseos_middle,
        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
    if uncond_prompt is not None:
@@ -346,6 +376,7 @@
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            no_boseos_middle=no_boseos_middle,
            clip_skip=clip_skip,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
@@ -362,6 +393,11 @@
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if pipe.text_encoder is not None:
        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder, lora_scale)

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
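The scale/unscale pairing above mirrors the core pipelines under the PEFT backend: LoRA layers in the text encoder are scaled before encoding and restored afterwards, so repeated calls do not compound the scaling. A condensed sketch of that lifecycle (the helper name and the try/finally framing are mine, not from the diff):

```python
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def encode_with_lora_scale(pipe, text_input, lora_scale):
    # Temporarily multiply every LoRA layer in the text encoder by lora_scale.
    if lora_scale is not None and USE_PEFT_BACKEND:
        scale_lora_layers(pipe.text_encoder, lora_scale)
    try:
        embeddings = pipe.text_encoder(text_input)[0]
    finally:
        # Divide the scale back out so the encoder returns to its original state.
        if lora_scale is not None and USE_PEFT_BACKEND:
            unscale_lora_layers(pipe.text_encoder, lora_scale)
    return embeddings
```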
@@ -549,6 +585,8 @@ class StableDiffusionLongPromptWeightingPipeline(
        max_embeddings_multiples=3,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        clip_skip: Optional[int] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
@@ -597,6 +635,8 @@
                prompt=prompt,
                uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                max_embeddings_multiples=max_embeddings_multiples,
                clip_skip=clip_skip,
                lora_scale=lora_scale,
            )
            if prompt_embeds is None:
                prompt_embeds = prompt_embeds1
@@ -790,6 +830,7 @@
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        clip_skip: Optional[int] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
@@ -865,6 +906,9 @@
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
@@ -903,6 +947,7 @@
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
@@ -914,6 +959,8 @@
            max_embeddings_multiples,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clip_skip=clip_skip,
            lora_scale=lora_scale,
        )
        dtype = prompt_embeds.dtype
@@ -1044,6 +1091,7 @@
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        clip_skip=None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
@@ -1101,6 +1149,9 @@
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
@@ -1135,6 +1186,7 @@
            return_dict=return_dict,
            callback=callback,
            is_cancelled_callback=is_cancelled_callback,
            clip_skip=clip_skip,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
        )
......
examples/community/lpw_stable_diffusion_xl.py
@@ -25,21 +25,25 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    is_invisible_watermark_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    lora_scale: Optional[int] = None,
):
    """
    This function can process long prompt with weights, no length limitation
@@ -281,6 +286,24 @@
    """
    device = device or pipe._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
        pipe._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if pipe.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder, lora_scale)

        if pipe.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder_2, lora_scale)

    if prompt_2:
        prompt = f"{prompt} {prompt_2}"
@@ -429,6 +452,16 @@
        bs_embed * num_images_per_prompt, -1
    )

    if pipe.text_encoder is not None:
        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder, lora_scale)

    if pipe.text_encoder_2 is not None:
        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
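On the SDXL side the same two options flow through `get_weighted_text_embeddings_sdxl`, and the scaling now covers both text encoders. A minimal usage sketch (checkpoint ID and LoRA path are placeholders):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="lpw_stable_diffusion_xl",
    torch_dtype=torch.float16,
).to("cuda")

# Hypothetical SDXL LoRA; its scale is applied to text_encoder and text_encoder_2.
pipe.load_lora_weights("path/to/sdxl-lora")

image = pipe(
    prompt="a (castle:1.3) on a cliff at sunset",
    clip_skip=2,
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```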
@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
    StableDiffusionMixin,
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
):
    r"""
@@ -561,8 +594,8 @@
    The pipeline also inherits the following loading methods:
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings

    Args:
@@ -743,7 +776,7 @@
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

        if prompt is not None and isinstance(prompt, str):
@@ -1612,7 +1645,9 @@
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 3. Encode input prompt
        (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
        lora_scale = (
            self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
        )

        negative_prompt = negative_prompt if negative_prompt is not None else ""
@@ -1627,6 +1662,7 @@
            neg_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            lora_scale=lora_scale,
        )
        dtype = prompt_embeds.dtype
......