Unverified Commit 4a4cdd6b authored by Sayak Paul, committed by GitHub

[Feat] Support SDXL Kohya-style LoRA (#4287)



* sdxl lora changes.

* better name replacement.

* better replacement.

* debugging

* debugging

* debugging

* debugging

* debugging

* remove print.

* print state dict keys.

* print

* distinguish better

* debuggable.

* fix: tests

* fix: arg from training script.

* access from class.

* run style

* debug

* save intermediate

* some simplifications for SDXL LoRA

* styling

* unet config is not needed in diffusers format.

* fix: dynamic SGM block mapping for SDXL kohya loras (#4322)

* Use lora compatible layers for linear proj_in/proj_out (#4323)

* improve condition for using the sgm_diffusers mapping

* informative comment.

* load compatible keys and embedding layer mapping.

* Get SDXL 1.0 example lora to load

* simplify

* specify ranks and hidden sizes.

* better handling of k rank and hidden

* debug

* debug

* debug

* debug

* debug

* fix: alpha keys

* add check for handling LoRAAttnAddedKVProcessor

* sanity comment

* modifications for text encoder SDXL

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* up

* up

* up

* up

* up

* up

* unneeded comments.

* unneeded comments.

* kwargs for the other attention processors.

* kwargs for the other attention processors.

* debugging

* debugging

* debugging

* debugging

* improve

* debugging

* debugging

* more print

* Fix alphas

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up

* clean up.

* debugging

* fix: text

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Batuhan Taskaya <batuhan@python.org>
parent b7b6d613
@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
\ No newline at end of file
### Supporting Stable Diffusion XL LoRAs trained using the Kohya trainer
With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
Here are some example checkpoints we tried out:
* SDXL 0.9:
* https://civitai.com/models/22279?modelVersionId=118556
* https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
* https://civitai.com/models/108448/daiton-sdxl-test
* https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
* SDXL 1.0:
* https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
Here is an example of how to perform inference with these checkpoints in `diffusers`:
```python
from diffusers import DiffusionPipeline
import torch
base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
generator = torch.manual_seed(2947883060)
num_inference_steps = 30
guidance_scale = 7
image = pipeline(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
generator=generator, guidance_scale=guidance_scale
).images[0]
image.save("Kamepan.png")
```
`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556 .
Note that the inference UX is identical to what we presented in the sections above. Below is a similar sketch for the SDXL 1.0 offset LoRA listed earlier, loaded directly from the Hub.
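This is a minimal sketch; it assumes the weight name as it appears in the `stabilityai/stable-diffusion-xl-base-1.0` repository linked above:

```python
from diffusers import DiffusionPipeline
import torch

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")

# The offset LoRA file lives in the same Hub repository as the base model,
# so the repo id doubles as the LoRA model id here.
pipeline.load_lora_weights(base_model_id, weight_name="sd_xl_offset_example-lora_1.0.safetensors")
```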
Thanks to [@isidentical](https://github.com/isidentical) for helping us on integrating this feature.
### Known limitations specific to the Kohya-style LoRAs
* SDXL LoRAs that contain parameters for both text encoders currently lead to weird results. We're actively investigating the issue.
* When images don't look similar to those from other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736). If this happens, you can unload the LoRA and fall back to the base pipeline, as shown in the sketch below.
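Here is a minimal sketch of that fallback. It reuses the `lora_model_id` and `lora_filename` variables from the earlier example and relies on `unload_lora_weights()`, which `LoraLoaderMixin` provides:

```python
# Remove the currently loaded LoRA layers from the pipeline.
pipeline.unload_lora_weights()

# The pipeline now runs without the LoRA, or you can load a different checkpoint.
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```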
\ No newline at end of file
@@ -925,10 +925,10 @@ def main(args):
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
......
@@ -825,13 +825,13 @@ def main(args):
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
        )
        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
@@ -56,7 +57,6 @@ UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

-TOTAL_EXAMPLE_KEYS = 5
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -257,7 +257,7 @@ class UNet2DConditionLoadersMixin:
        use_safetensors = kwargs.pop("use_safetensors", None)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        network_alpha = kwargs.pop("network_alpha", None)
+        network_alphas = kwargs.pop("network_alphas", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
@@ -322,7 +322,7 @@ class UNet2DConditionLoadersMixin:
        attn_processors = {}
        non_attn_lora_layers = []

-        is_lora = all("lora" in k for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

        if is_lora:
@@ -339,10 +339,25 @@ class UNet2DConditionLoadersMixin:
            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}

            lora_grouped_dict = defaultdict(dict)
-            for key, value in state_dict.items():
+            mapped_network_alphas = {}
+
+            all_keys = list(state_dict.keys())
+            for key in all_keys:
+                value = state_dict.pop(key)
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

+                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+                if network_alphas is not None:
+                    for k in network_alphas:
+                        if k.replace(".alpha", "") in key:
+                            mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
+
+            if len(state_dict) > 0:
+                raise ValueError(
+                    f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                )

            for key, value_dict in lora_grouped_dict.items():
                attn_processor = self
                for sub_key in key.split("."):
@@ -352,13 +367,27 @@ class UNet2DConditionLoadersMixin:
                # or add_{k,v,q,out_proj}_proj_lora layers.
                if "lora.down.weight" in value_dict:
                    rank = value_dict["lora.down.weight"].shape[0]
-                    hidden_size = value_dict["lora.up.weight"].shape[0]

                    if isinstance(attn_processor, LoRACompatibleConv):
-                        lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                        in_features = attn_processor.in_channels
+                        out_features = attn_processor.out_channels
+                        kernel_size = attn_processor.kernel_size
+
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
                    elif isinstance(attn_processor, LoRACompatibleLinear):
                        lora = LoRALinearLayer(
-                            attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
                        )
                    else:
                        raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
@@ -366,32 +395,64 @@ class UNet2DConditionLoadersMixin:
                    value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
                    lora.load_state_dict(value_dict)
                    non_attn_lora_layers.append((attn_processor, lora))
-                    continue
-
-                rank = value_dict["to_k_lora.down.weight"].shape[0]
-                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
-                if isinstance(
-                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
-                ):
-                    cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
-                    attn_processor_class = LoRAAttnAddedKVProcessor
-                else:
-                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
-                        attn_processor_class = LoRAXFormersAttnProcessor
-                    else:
-                        attn_processor_class = (
-                            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-                        )
-
-                attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    rank=rank,
-                    network_alpha=network_alpha,
-                )
-                attn_processors[key].load_state_dict(value_dict)
+                else:
+                    # To handle SDXL.
+                    rank_mapping = {}
+                    hidden_size_mapping = {}
+                    for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
+                        rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
+                        hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]
+                        rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
+                        hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})
+
+                    if isinstance(
+                        attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
+                    ):
+                        cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
+                        attn_processor_class = LoRAAttnAddedKVProcessor
+                    else:
+                        cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+                        if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+                            attn_processor_class = LoRAXFormersAttnProcessor
+                        else:
+                            attn_processor_class = (
+                                LoRAAttnProcessor2_0
+                                if hasattr(F, "scaled_dot_product_attention")
+                                else LoRAAttnProcessor
+                            )
+
+                    if attn_processor_class is not LoRAAttnAddedKVProcessor:
+                        attn_processors[key] = attn_processor_class(
+                            rank=rank_mapping.get("to_k_lora.down.weight"),
+                            hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
+                            cross_attention_dim=cross_attention_dim,
+                            network_alpha=mapped_network_alphas.get(key),
+                            q_rank=rank_mapping.get("to_q_lora.down.weight"),
+                            q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
+                            v_rank=rank_mapping.get("to_v_lora.down.weight"),
+                            v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
+                            out_rank=rank_mapping.get("to_out_lora.down.weight"),
+                            out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
+                            # rank=rank_mapping.get("to_k_lora.down.weight", None),
+                            # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
+                            # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
+                            # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
+                            # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
+                            # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
+                            # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
+                            # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
+                        )
+                    else:
+                        attn_processors[key] = attn_processor_class(
+                            rank=rank_mapping.get("to_k_lora.down.weight", None),
+                            hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
+                            cross_attention_dim=cross_attention_dim,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
+                    attn_processors[key].load_state_dict(value_dict)

        elif is_custom_diffusion:
            custom_diffusion_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
@@ -434,8 +495,10 @@ class UNet2DConditionLoadersMixin:
        # set ff layers
        for target_module, lora_layer in non_attn_lora_layers:
-            if hasattr(target_module, "set_lora_layer"):
-                target_module.set_lora_layer(lora_layer)
+            target_module.set_lora_layer(lora_layer)
+            # It should raise an error if we don't have a set lora here
+            # if hasattr(target_module, "set_lora_layer"):
+            #     target_module.set_lora_layer(lora_layer)

    def save_attn_procs(
        self,
@@ -873,11 +936,11 @@ class LoraLoaderMixin:
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
-        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
        self.load_lora_into_text_encoder(
            state_dict,
-            network_alpha=network_alpha,
+            network_alphas=network_alphas,
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
        )
@@ -889,7 +952,7 @@ class LoraLoaderMixin:
        **kwargs,
    ):
        r"""
-        Return state dict for lora weights
+        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>
@@ -950,6 +1013,7 @@ class LoraLoaderMixin:
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
+        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
@@ -1011,53 +1075,158 @@ class LoraLoaderMixin:
        else:
            state_dict = pretrained_model_name_or_path_or_dict

-        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
-        network_alpha = None
-        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
+        network_alphas = None
+        if all(
+            (
+                k.startswith("lora_te_")
+                or k.startswith("lora_unet_")
+                or k.startswith("lora_te1_")
+                or k.startswith("lora_te2_")
+            )
+            for k in state_dict.keys()
+        ):
+            # Map SDXL blocks correctly.
+            if unet_config is not None:
+                # use unet config to remap block numbers
+                state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
+            state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

-        return state_dict, network_alpha
+        return state_dict, network_alphas
+    @classmethod
+    def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+        is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
+        new_state_dict = {}
+        inner_block_map = ["resnets", "attentions", "upsamplers"]
+
+        # Retrieves # of down, mid and up blocks
+        input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
+        for layer in state_dict:
+            if "text" not in layer:
+                layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
+                if "input_blocks" in layer:
+                    input_block_ids.add(layer_id)
+                elif "middle_block" in layer:
+                    middle_block_ids.add(layer_id)
+                elif "output_blocks" in layer:
+                    output_block_ids.add(layer_id)
+                else:
+                    raise ValueError("Checkpoint not supported")
+
+        input_blocks = {
+            layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
+            for layer_id in input_block_ids
+        }
+        middle_blocks = {
+            layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
+            for layer_id in middle_block_ids
+        }
+        output_blocks = {
+            layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
+            for layer_id in output_block_ids
+        }
+
+        # Rename keys accordingly
+        for i in input_block_ids:
+            block_id = (i - 1) // (unet_config.layers_per_block + 1)
+            layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
+
+            for key in input_blocks[i]:
+                inner_block_id = int(key.split(delimiter)[block_slice_pos])
+                inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
+                inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
+                new_key = delimiter.join(
+                    key.split(delimiter)[: block_slice_pos - 1]
+                    + [str(block_id), inner_block_key, inner_layers_in_block]
+                    + key.split(delimiter)[block_slice_pos + 1 :]
+                )
+                new_state_dict[new_key] = state_dict.pop(key)
+
+        for i in middle_block_ids:
+            key_part = None
+            if i == 0:
+                key_part = [inner_block_map[0], "0"]
+            elif i == 1:
+                key_part = [inner_block_map[1], "0"]
+            elif i == 2:
+                key_part = [inner_block_map[0], "1"]
+            else:
+                raise ValueError(f"Invalid middle block id {i}.")
+
+            for key in middle_blocks[i]:
+                new_key = delimiter.join(
+                    key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
+                )
+                new_state_dict[new_key] = state_dict.pop(key)
+
+        for i in output_block_ids:
+            block_id = i // (unet_config.layers_per_block + 1)
+            layer_in_block_id = i % (unet_config.layers_per_block + 1)
+
+            for key in output_blocks[i]:
+                inner_block_id = int(key.split(delimiter)[block_slice_pos])
+                inner_block_key = inner_block_map[inner_block_id]
+                inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
+                new_key = delimiter.join(
+                    key.split(delimiter)[: block_slice_pos - 1]
+                    + [str(block_id), inner_block_key, inner_layers_in_block]
+                    + key.split(delimiter)[block_slice_pos + 1 :]
+                )
+                new_state_dict[new_key] = state_dict.pop(key)
+
+        if is_all_unet and len(state_dict) > 0:
+            raise ValueError("At this point all state dict entries have to be converted.")
+        else:
+            # Remaining is the text encoder state dict.
+            for k, v in state_dict.items():
+                new_state_dict.update({k: v})
+
+        return new_state_dict
    @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alpha, unet):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet):
        """
-        This will load the LoRA layers specified in `state_dict` into `unet`
+        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
-            network_alpha (`float`):
+            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
        """
        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())

        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
-            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
            logger.info(f"Loading {cls.unet_name}.")
-            unet_lora_state_dict = {
-                k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
-            }
-            unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
-
-        # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-        # contain the module names of the `unet` as its keys WITHOUT any prefix.
-        elif not all(
-            key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
-        ):
-            unet.load_attn_procs(state_dict, network_alpha=network_alpha)
+
+            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+            state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
+                network_alphas = {
+                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+        else:
+            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
+            # contain the module names of the `unet` as its keys WITHOUT any prefix.
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
            warnings.warn(warn_message)
+
+        # load loras into unet
+        unet.load_attn_procs(state_dict, network_alphas=network_alphas)
    @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
+    def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

@@ -1065,7 +1234,7 @@ class LoraLoaderMixin:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
-            network_alpha (`float`):
+            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
@@ -1134,14 +1303,19 @@ class LoraLoaderMixin:
            ].shape[1]
            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

-            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
+            cls._modify_text_encoder(
+                text_encoder,
+                lora_scale,
+                network_alphas,
+                rank=rank,
+                patch_mlp=patch_mlp,
+            )

            # set correct dtype & device
            text_encoder_lora_state_dict = {
                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
                for k, v in text_encoder_lora_state_dict.items()
            }

            load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
            if len(load_state_dict_results.unexpected_keys) != 0:
                raise ValueError(
@@ -1176,7 +1350,7 @@ class LoraLoaderMixin:
        cls,
        text_encoder,
        lora_scale=1,
-        network_alpha=None,
+        network_alphas=None,
        rank=4,
        dtype=None,
        patch_mlp=False,
@@ -1189,37 +1363,46 @@ class LoraLoaderMixin:
        cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)

        lora_parameters = []
+        network_alphas = {} if network_alphas is None else network_alphas

-        for _, attn_module in text_encoder_attn_modules(text_encoder):
+        for name, attn_module in text_encoder_attn_modules(text_encoder):
+            query_alpha = network_alphas.get(name + ".k.proj.alpha")
+            key_alpha = network_alphas.get(name + ".q.proj.alpha")
+            value_alpha = network_alphas.get(name + ".v.proj.alpha")
+            proj_alpha = network_alphas.get(name + ".out.proj.alpha")
+
            attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())

            attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())

            attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

            attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

        if patch_mlp:
-            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            for name, mlp_module in text_encoder_mlp_modules(text_encoder):
+                fc1_alpha = network_alphas.get(name + ".fc1.alpha")
+                fc2_alpha = network_alphas.get(name + ".fc2.alpha")
+
                mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

                mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
@@ -1326,77 +1509,163 @@ class LoraLoaderMixin:
    def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
-        network_alpha = None
-        unloaded_keys = []
+        te2_state_dict = {}
+        network_alphas = {}

-        for key, value in state_dict.items():
-            if "hada" in key or "skip" in key:
-                unloaded_keys.append(key)
-            elif "lora_down" in key:
-                lora_name = key.split(".")[0]
-                lora_name_up = lora_name + ".lora_up.weight"
-                lora_name_alpha = lora_name + ".alpha"
-                if lora_name_alpha in state_dict:
-                    alpha = state_dict[lora_name_alpha].item()
-                    if network_alpha is None:
-                        network_alpha = alpha
-                    elif network_alpha != alpha:
-                        raise ValueError("Network alpha is not consistent")
-
-                if lora_name.startswith("lora_unet_"):
-                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
-                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-                    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-                    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-                    if "transformer_blocks" in diffusers_name:
-                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                            unet_state_dict[diffusers_name] = value
-                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                    elif "ff" in diffusers_name:
-                        unet_state_dict[diffusers_name] = value
-                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                        unet_state_dict[diffusers_name] = value
-                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                elif lora_name.startswith("lora_te_"):
-                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
-                    diffusers_name = diffusers_name.replace("text.model", "text_model")
-                    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-                    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-                    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-                    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-                    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-                    if "self_attn" in diffusers_name:
-                        te_state_dict[diffusers_name] = value
-                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                    elif "mlp" in diffusers_name:
-                        # Be aware that this is the new diffusers convention and the rest of the code might
-                        # not utilize it yet.
-                        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                        te_state_dict[diffusers_name] = value
-                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-
-        logger.info("Kohya-style checkpoint detected.")
-        if len(unloaded_keys) > 0:
-            example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
-            logger.warning(
-                f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
-            )
-
-        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
-        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+        # every down weight has a corresponding up weight and potentially an alpha weight
+        lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
+        for key in lora_keys:
+            lora_name = key.split(".")[0]
+            lora_name_up = lora_name + ".lora_up.weight"
+            lora_name_alpha = lora_name + ".alpha"
+
+            # if lora_name_alpha in state_dict:
+            #     alpha = state_dict.pop(lora_name_alpha).item()
+            #     network_alphas.update({lora_name_alpha: alpha})
+
+            if lora_name.startswith("lora_unet_"):
+                diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+                if "input.blocks" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+                else:
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+
+                if "middle.block" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+                else:
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                if "output.blocks" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+                else:
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+
+                diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+                diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+                # SDXL specificity.
+                if "emb" in diffusers_name:
+                    pattern = r"\.\d+(?=\D*$)"
+                    diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+                if ".in." in diffusers_name:
+                    diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+                if ".out." in diffusers_name:
+                    diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+                if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("op", "conv")
+                if "skip" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+                if "transformer_blocks" in diffusers_name:
+                    if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                        unet_state_dict[diffusers_name] = state_dict.pop(key)
+                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "ff" in diffusers_name:
+                    unet_state_dict[diffusers_name] = state_dict.pop(key)
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                    unet_state_dict[diffusers_name] = state_dict.pop(key)
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                else:
+                    unet_state_dict[diffusers_name] = state_dict.pop(key)
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            elif lora_name.startswith("lora_te_"):
+                diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                diffusers_name = diffusers_name.replace("text.model", "text_model")
+                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                if "self_attn" in diffusers_name:
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # (sayakpaul): Duplicate code. Needs to be cleaned.
+            elif lora_name.startswith("lora_te1_"):
+                diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
+                diffusers_name = diffusers_name.replace("text.model", "text_model")
+                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                if "self_attn" in diffusers_name:
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # (sayakpaul): Duplicate code. Needs to be cleaned.
+            elif lora_name.startswith("lora_te2_"):
+                diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
+                diffusers_name = diffusers_name.replace("text.model", "text_model")
+                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                if "self_attn" in diffusers_name:
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # Rename the alphas so that they can be mapped appropriately.
+            if lora_name_alpha in state_dict:
+                alpha = state_dict.pop(lora_name_alpha).item()
+                if lora_name_alpha.startswith("lora_unet_"):
+                    prefix = "unet."
+                elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+                    prefix = "text_encoder."
+                else:
+                    prefix = "text_encoder_2."
+                new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+                network_alphas.update({new_name: alpha})
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
+            )
+
+        logger.info("Kohya-style checkpoint detected.")
+        unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+        te_state_dict = {
+            f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
+        }
+        te2_state_dict = (
+            {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+            if len(te2_state_dict) > 0
+            else None
+        )
+        if te2_state_dict is not None:
+            te_state_dict.update(te2_state_dict)

        new_state_dict = {**unet_state_dict, **te_state_dict}
-        return new_state_dict, network_alpha
+        return new_state_dict, network_alphas
    def unload_lora_weights(self):
        """
......
@@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module):
            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module):
    """

    def __init__(
-        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+        self,
+        hidden_size,
+        cross_attention_dim,
+        rank=4,
+        attention_op: Optional[Callable] = None,
+        network_alpha=None,
+        **kwargs,
    ):
        super().__init__()

@@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module):
        self.rank = rank
        self.attention_op = attention_op

+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module):
            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

@@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module):
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states
......
@@ -49,14 +49,19 @@ class LoRALinearLayer(nn.Module):

class LoRAConv2dLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
+    def __init__(
+        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+    ):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

-        self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
-        self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
+        self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
+        # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
+        # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
......
@@ -23,6 +23,7 @@
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
+from .lora import LoRACompatibleConv, LoRACompatibleLinear


class Upsample1D(nn.Module):

@@ -126,7 +127,7 @@ class Upsample2D(nn.Module):
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
-            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":

@@ -196,7 +197,7 @@ class Downsample2D(nn.Module):
        self.name = name

        if use_conv:
-            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

@@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module):
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

-        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
-                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:

@@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module):
        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

@@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module):
        self.conv_shortcut = None
        if self.use_in_shortcut:
-            self.conv_shortcut = torch.nn.Conv2d(
+            self.conv_shortcut = LoRACompatibleConv(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )
......
@@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
-from .lora import LoRACompatibleConv
+from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin

@@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
-                self.proj_in = nn.Linear(in_channels, inner_dim)
+                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:

@@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
-                self.proj_out = nn.Linear(inner_dim, in_channels)
+                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
......
@@ -88,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):

@@ -866,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
    # Overrride to properly handle the loading and unloading of the additional text encoder.
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
-        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
-                network_alpha=network_alpha,
+                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,

@@ -883,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
-                network_alpha=network_alpha,
+                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
......
@@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
        ).images
        images = images[0, -3:, -3:, -1].flatten()

-        expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

        self.assertTrue(np.allclose(images, expected, atol=1e-4))
......