Unverified commit f240a936, authored by Anatoly Belikov, committed by GitHub

handle lora scale and clip skip in lpw sd and sdxl community pipelines (#8988)



* handle lora scale and clip skip in lpw sd and sdxl

* use StableDiffusionLoraLoaderMixin

* use StableDiffusionXLLoraLoaderMixin

* style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 00d8d46e
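
For orientation before the diff: after this change, callers of the long-prompt-weighting (lpw) community pipeline can pass `clip_skip` directly, and the LoRA scale given in `cross_attention_kwargs` is applied to the text encoder as well as the UNet. A minimal usage sketch, assuming the diffusers community-pipeline loader; the model id and LoRA path below are placeholders, not part of this commit:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder model id
        custom_pipeline="lpw_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora")  # placeholder LoRA checkpoint

    image = pipe(
        prompt="a (white) cat running on the grass",
        clip_skip=1,  # 1 = use the pre-final CLIP layer
        cross_attention_kwargs={"scale": 0.5},  # LoRA scale, now also reaching the text encoder
    ).images[0]
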
examples/community/lpw_stable_diffusion.py

@@ -13,13 +13,17 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
+    clip_skip: Optional[int] = None,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -214,7 +219,20 @@
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            if clip_skip is None:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
+                text_embedding = prompt_embeds[0]
+            else:
+                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)

             if no_boseos_middle:
                 if i == 0:
@@ -230,7 +248,10 @@
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None:
+            clip_skip = 0
+        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
+        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
     return text_embeddings
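
The `clip_skip` handling above mirrors the main diffusers pipelines: with `output_hidden_states=True` the CLIP text encoder returns its per-layer hidden states as the last element of its output, `-(clip_skip + 1)` selects the layer `clip_skip` steps before the last, and the final LayerNorm is applied manually because only `last_hidden_state` normally passes through it. A standalone sketch of the same indexing against transformers' CLIP; the model id is illustrative only:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    tokens = tokenizer("a photo of a cat", return_tensors="pt").input_ids
    with torch.no_grad():
        out = encoder(tokens, output_hidden_states=True)

    clip_skip = 1
    # hidden_states holds the embedding output plus every encoder layer;
    # -(clip_skip + 1) is the pre-final layer when clip_skip == 1.
    hidden = out.hidden_states[-(clip_skip + 1)]
    # Apply the final LayerNorm manually, as last_hidden_state would have.
    embeds = encoder.text_model.final_layer_norm(hidden)
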
@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
+    clip_skip=None,
+    lora_scale=None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe,
-        prompt_tokens,
-        pipe.tokenizer.model_max_length,
-        no_boseos_middle=no_boseos_middle,
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -346,6 +376,7 @@
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
@@ -362,6 +393,11 @@
         current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
         uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
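
With the PEFT backend, `scale_lora_layers` multiplies the text encoder's LoRA deltas by `lora_scale` before encoding and `unscale_lora_layers` restores them afterwards, so the function leaves the pipeline in its original state. A schematic of that bracketing pattern; `encode_with_lora_scale` is a hypothetical helper, not part of the commit, while the two utilities are real diffusers imports:

    from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

    def encode_with_lora_scale(pipe, text_input, lora_scale):
        # Scale up, encode, then scale back so the text encoder is untouched.
        if lora_scale is not None and USE_PEFT_BACKEND:
            scale_lora_layers(pipe.text_encoder, lora_scale)
        try:
            embeds = pipe.text_encoder(text_input.to(pipe.device))[0]
        finally:
            if lora_scale is not None and USE_PEFT_BACKEND:
                unscale_lora_layers(pipe.text_encoder, lora_scale)
        return embeds

The diff itself unscales after computing the embeddings rather than using try/finally; on the success path the effect is the same.
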
@@ -549,6 +585,8 @@ class StableDiffusionLongPromptWeightingPipeline(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        clip_skip: Optional[int] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -597,6 +635,8 @@ class StableDiffusionLongPromptWeightingPipeline(
                 prompt=prompt,
                 uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                 max_embeddings_multiples=max_embeddings_multiples,
+                clip_skip=clip_skip,
+                lora_scale=lora_scale,
             )
             if prompt_embeds is None:
                 prompt_embeds = prompt_embeds1
@@ -790,6 +830,7 @@ class StableDiffusionLongPromptWeightingPipeline(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip: Optional[int] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -865,6 +906,9 @@ class StableDiffusionLongPromptWeightingPipeline(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -903,6 +947,7 @@ class StableDiffusionLongPromptWeightingPipeline(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -914,6 +959,8 @@ class StableDiffusionLongPromptWeightingPipeline(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype
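
Before this hunk, the `scale` entry of `cross_attention_kwargs` was forwarded only to the UNet, so a text encoder LoRA in this community pipeline always ran at full strength; now the same value is plumbed into `_encode_prompt` as `lora_scale`. A hedged comparison, reusing the `pipe` from the sketch near the top (values are illustrative):

    prompt = "a portrait photo, (sharp focus)"
    weak = pipe(prompt, cross_attention_kwargs={"scale": 0.1}).images[0]    # LoRA nearly off, UNet and text encoder alike
    strong = pipe(prompt, cross_attention_kwargs={"scale": 1.0}).images[0]  # LoRA fully applied
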
@@ -1044,6 +1091,7 @@ class StableDiffusionLongPromptWeightingPipeline(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+        clip_skip=None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -1101,6 +1149,9 @@ class StableDiffusionLongPromptWeightingPipeline(
             is_cancelled_callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. If the function returns
                 `True`, the inference will be cancelled.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1135,6 +1186,7 @@ class StableDiffusionLongPromptWeightingPipeline(
             return_dict=return_dict,
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
+            clip_skip=clip_skip,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
         )
examples/community/lpw_stable_diffusion_sdxl.py

@@ -25,21 +25,25 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import (
     FromSingleFileMixin,
     IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
+    lora_scale: Optional[int] = None,
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
     """
     device = device or pipe._execution_device

+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if pipe.text_encoder is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder, lora_scale)
+
+        if pipe.text_encoder_2 is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     if prompt_2:
         prompt = f"{prompt} {prompt_2}"
@@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
         bs_embed * num_images_per_prompt, -1
     )

+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+    if pipe.text_encoder_2 is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
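
SDXL pairs two CLIP text encoders, so the scale/unscale bracketing has to touch both `text_encoder` and `text_encoder_2`, which is why the block above is repeated. A compact sketch of the same logic with a hypothetical helper (`set_text_encoder_lora_scale` is not part of the commit; the imported utilities are real):

    from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

    def set_text_encoder_lora_scale(pipe, lora_scale, up=True):
        # Apply (or undo) the LoRA scale on both SDXL text encoders.
        fn = scale_lora_layers if up else unscale_lora_layers
        for encoder in (pipe.text_encoder, pipe.text_encoder_2):
            if encoder is not None and USE_PEFT_BACKEND:
                fn(encoder, lora_scale)
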
@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
     StableDiffusionMixin,
     FromSingleFileMixin,
     IPAdapterMixin,
-    StableDiffusionLoraLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 ):
     r"""
@@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
     The pipeline also inherits the following loading methods:
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings

     Args:
@@ -743,7 +776,7 @@ class SDXLLongPromptWeightingPipeline(
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
             self._lora_scale = lora_scale

         if prompt is not None and isinstance(prompt, str):
@@ -1612,7 +1645,9 @@ class SDXLLongPromptWeightingPipeline(
         image_embeds = torch.cat([negative_image_embeds, image_embeds])

         # 3. Encode input prompt
-        (self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
+        lora_scale = (
+            self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
+        )

         negative_prompt = negative_prompt if negative_prompt is not None else ""
@@ -1627,6 +1662,7 @@ class SDXLLongPromptWeightingPipeline(
             neg_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             clip_skip=clip_skip,
+            lora_scale=lora_scale,
         )
         dtype = prompt_embeds.dtype
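
End-to-end, the SDXL variant is exercised the same way as the SD one. A minimal usage sketch, assuming the community-pipeline loader; the model id and LoRA path are placeholders:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model id
        custom_pipeline="lpw_stable_diffusion_xl",
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.load_lora_weights("path/to/sdxl_lora")  # placeholder LoRA checkpoint

    image = pipe(
        prompt="a (colorful) origami crane, studio lighting",
        clip_skip=1,
        cross_attention_kwargs={"scale": 0.6},  # forwarded as lora_scale to both text encoders
    ).images[0]
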