"git@developer.sourcefind.cn:change/sglang.git" did not exist on "2f80bd9f0e1ff7e6fb19d2fe2ca3d1587bf1d0c7"
Unverified Commit f240a936 authored by Anatoly Belikov's avatar Anatoly Belikov Committed by GitHub
Browse files

handle lora scale and clip skip in lpw sd and sdxl community pipelines (#8988)



* handle lora scale and clip skip in lpw sd and sdxl

* use StableDiffusionLoraLoaderMixin

* use StableDiffusionXLLoraLoaderMixin

* style

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 00d8d46e
...@@ -13,13 +13,17 @@ from diffusers.configuration_utils import FrozenDict ...@@ -13,13 +13,17 @@ from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import ( from diffusers.utils import (
PIL_INTERPOLATION, PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate, deprecate,
logging, logging,
scale_lora_layers,
unscale_lora_layers,
) )
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
...@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings( ...@@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
text_input: torch.Tensor, text_input: torch.Tensor,
chunk_length: int, chunk_length: int,
no_boseos_middle: Optional[bool] = True, no_boseos_middle: Optional[bool] = True,
clip_skip: Optional[int] = None,
): ):
""" """
When the length of tokens is a multiple of the capacity of the text encoder, When the length of tokens is a multiple of the capacity of the text encoder,
...@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings( ...@@ -214,7 +219,20 @@ def get_unweighted_text_embeddings(
# cover the head and the tail by the starting and the ending tokens # cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0] text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1] text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk)[0] if clip_skip is None:
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
text_embedding = prompt_embeds[0]
else:
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
if no_boseos_middle: if no_boseos_middle:
if i == 0: if i == 0:
...@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings( ...@@ -230,7 +248,10 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding) text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1) text_embeddings = torch.concat(text_embeddings, axis=1)
else: else:
text_embeddings = pipe.text_encoder(text_input)[0] if clip_skip is None:
clip_skip = 0
prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
return text_embeddings return text_embeddings
...@@ -242,6 +263,8 @@ def get_weighted_text_embeddings( ...@@ -242,6 +263,8 @@ def get_weighted_text_embeddings(
no_boseos_middle: Optional[bool] = False, no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False, skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False, skip_weighting: Optional[bool] = False,
clip_skip=None,
lora_scale=None,
): ):
r""" r"""
Prompts can be assigned with local weights using brackets. For example, Prompts can be assigned with local weights using brackets. For example,
...@@ -268,6 +291,16 @@ def get_weighted_text_embeddings( ...@@ -268,6 +291,16 @@ def get_weighted_text_embeddings(
skip_weighting (`bool`, *optional*, defaults to `False`): skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True. Skip the weighting. When the parsing is skipped, it is forced True.
""" """
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
pipe._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
else:
scale_lora_layers(pipe.text_encoder, lora_scale)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str): if isinstance(prompt, str):
prompt = [prompt] prompt = [prompt]
...@@ -334,10 +367,7 @@ def get_weighted_text_embeddings( ...@@ -334,10 +367,7 @@ def get_weighted_text_embeddings(
# get the embeddings # get the embeddings
text_embeddings = get_unweighted_text_embeddings( text_embeddings = get_unweighted_text_embeddings(
pipe, pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
) )
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device) prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
if uncond_prompt is not None: if uncond_prompt is not None:
...@@ -346,6 +376,7 @@ def get_weighted_text_embeddings( ...@@ -346,6 +376,7 @@ def get_weighted_text_embeddings(
uncond_tokens, uncond_tokens,
pipe.tokenizer.model_max_length, pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle, no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip,
) )
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device) uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
...@@ -362,6 +393,11 @@ def get_weighted_text_embeddings( ...@@ -362,6 +393,11 @@ def get_weighted_text_embeddings(
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if pipe.text_encoder is not None:
if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder, lora_scale)
if uncond_prompt is not None: if uncond_prompt is not None:
return text_embeddings, uncond_embeddings return text_embeddings, uncond_embeddings
return text_embeddings, None return text_embeddings, None
...@@ -549,6 +585,8 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -549,6 +585,8 @@ class StableDiffusionLongPromptWeightingPipeline(
max_embeddings_multiples=3, max_embeddings_multiples=3,
prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[float] = None,
): ):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -597,6 +635,8 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -597,6 +635,8 @@ class StableDiffusionLongPromptWeightingPipeline(
prompt=prompt, prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None, uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples, max_embeddings_multiples=max_embeddings_multiples,
clip_skip=clip_skip,
lora_scale=lora_scale,
) )
if prompt_embeds is None: if prompt_embeds is None:
prompt_embeds = prompt_embeds1 prompt_embeds = prompt_embeds1
...@@ -790,6 +830,7 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -790,6 +830,7 @@ class StableDiffusionLongPromptWeightingPipeline(
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip: Optional[int] = None,
callback_steps: int = 1, callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
): ):
...@@ -865,6 +906,9 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -865,6 +906,9 @@ class StableDiffusionLongPromptWeightingPipeline(
is_cancelled_callback (`Callable`, *optional*): is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled. `True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1): callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
...@@ -903,6 +947,7 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -903,6 +947,7 @@ class StableDiffusionLongPromptWeightingPipeline(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
...@@ -914,6 +959,8 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -914,6 +959,8 @@ class StableDiffusionLongPromptWeightingPipeline(
max_embeddings_multiples, max_embeddings_multiples,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
clip_skip=clip_skip,
lora_scale=lora_scale,
) )
dtype = prompt_embeds.dtype dtype = prompt_embeds.dtype
...@@ -1044,6 +1091,7 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -1044,6 +1091,7 @@ class StableDiffusionLongPromptWeightingPipeline(
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip=None,
callback_steps: int = 1, callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
): ):
...@@ -1101,6 +1149,9 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -1101,6 +1149,9 @@ class StableDiffusionLongPromptWeightingPipeline(
is_cancelled_callback (`Callable`, *optional*): is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled. `True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1): callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
...@@ -1135,6 +1186,7 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -1135,6 +1186,7 @@ class StableDiffusionLongPromptWeightingPipeline(
return_dict=return_dict, return_dict=return_dict,
callback=callback, callback=callback,
is_cancelled_callback=is_cancelled_callback, is_cancelled_callback=is_cancelled_callback,
clip_skip=clip_skip,
callback_steps=callback_steps, callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
) )
......
...@@ -25,21 +25,25 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor ...@@ -25,21 +25,25 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import ( from diffusers.loaders import (
FromSingleFileMixin, FromSingleFileMixin,
IPAdapterMixin, IPAdapterMixin,
StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
) )
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import ( from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_invisible_watermark_available, is_invisible_watermark_available,
logging, logging,
replace_example_docstring, replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
) )
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
...@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl( ...@@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
lora_scale: Optional[int] = None,
): ):
""" """
This function can process long prompt with weights, no length limitation This function can process long prompt with weights, no length limitation
...@@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl( ...@@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
""" """
device = device or pipe._execution_device device = device or pipe._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
pipe._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if pipe.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
else:
scale_lora_layers(pipe.text_encoder, lora_scale)
if pipe.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
else:
scale_lora_layers(pipe.text_encoder_2, lora_scale)
if prompt_2: if prompt_2:
prompt = f"{prompt} {prompt_2}" prompt = f"{prompt} {prompt_2}"
...@@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl( ...@@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
bs_embed * num_images_per_prompt, -1 bs_embed * num_images_per_prompt, -1
) )
if pipe.text_encoder is not None:
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder, lora_scale)
if pipe.text_encoder_2 is not None:
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
...@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline( ...@@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
StableDiffusionMixin, StableDiffusionMixin,
FromSingleFileMixin, FromSingleFileMixin,
IPAdapterMixin, IPAdapterMixin,
StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin, TextualInversionLoaderMixin,
): ):
r""" r"""
...@@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline( ...@@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
The pipeline also inherits the following loading methods: The pipeline also inherits the following loading methods:
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
Args: Args:
...@@ -743,7 +776,7 @@ class SDXLLongPromptWeightingPipeline( ...@@ -743,7 +776,7 @@ class SDXLLongPromptWeightingPipeline(
# set lora scale so that monkey patched LoRA # set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it # function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
self._lora_scale = lora_scale self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
...@@ -1612,7 +1645,9 @@ class SDXLLongPromptWeightingPipeline( ...@@ -1612,7 +1645,9 @@ class SDXLLongPromptWeightingPipeline(
image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 3. Encode input prompt # 3. Encode input prompt
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None) lora_scale = (
self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
)
negative_prompt = negative_prompt if negative_prompt is not None else "" negative_prompt = negative_prompt if negative_prompt is not None else ""
...@@ -1627,6 +1662,7 @@ class SDXLLongPromptWeightingPipeline( ...@@ -1627,6 +1662,7 @@ class SDXLLongPromptWeightingPipeline(
neg_prompt=negative_prompt, neg_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip, clip_skip=clip_skip,
lora_scale=lora_scale,
) )
dtype = prompt_embeds.dtype dtype = prompt_embeds.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment