Unverified commit ff8f5808 authored by Batuhan Taskaya, committed by GitHub

Load Kohya-ss style LoRAs with auxiliary states (#4147)



* Support loading the Kohya-ss style LoRA file format (without restrictions)
Co-authored-by: Takuma Mori <takuma104@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* tmp: add sdxl to mlp_modules

---------
Co-authored-by: Takuma Mori <takuma104@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 161449d5
@@ -14,8 +14,7 @@ specific language governing permissions and limitations under the License.

<Tip warning={true}>

- Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage.
+ This is an experimental feature. Its APIs can change in the future.

</Tip>

@@ -286,6 +285,8 @@ You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pip
## Supporting A1111 themed LoRA checkpoints from Diffusers

This support was made possible because of our amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical).

To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
......
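For orientation, here is a minimal usage sketch of that flow. The repository and file names are the ones used in this PR's integration tests; the device and step count are arbitrary choices.

```python
import torch
from diffusers import StableDiffusionPipeline

# Minimal sketch: load an A1111/Kohya-ss formatted LoRA checkpoint into a pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")

pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)

image = pipe(
    "masterpiece, best quality, mountain",
    num_inference_steps=25,
    generator=torch.manual_seed(0),
).images[0]

# The injected LoRA layers can be removed again with:
pipe.unload_lora_weights()
```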
@@ -25,6 +25,7 @@ import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch import nn

from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,

@@ -56,6 +57,7 @@ UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TOTAL_EXAMPLE_KEYS = 5
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"

@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
    return attn_modules

def text_encoder_mlp_modules(text_encoder):
    mlp_modules = []

    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            mlp_mod = layer.mlp
            name = f"text_model.encoder.layers.{i}.mlp"
            mlp_modules.append((name, mlp_mod))
    else:
        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

    return mlp_modules


def text_encoder_lora_state_dict(text_encoder):
    state_dict = {}
@@ -304,6 +320,7 @@ class UNet2DConditionLoadersMixin:
        # fill attn processors
        attn_processors = {}
        non_attn_lora_layers = []

        is_lora = all("lora" in k for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

@@ -327,13 +344,33 @@ class UNet2DConditionLoadersMixin:
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)

                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
                # or add_{k,v,q,out_proj}_proj_lora layers.
                if "lora.down.weight" in value_dict:
                    rank = value_dict["lora.down.weight"].shape[0]
                    hidden_size = value_dict["lora.up.weight"].shape[0]

                    if isinstance(attn_processor, LoRACompatibleConv):
                        lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
                    elif isinstance(attn_processor, LoRACompatibleLinear):
                        lora = LoRALinearLayer(
                            attn_processor.in_features, attn_processor.out_features, rank, network_alpha
                        )
                    else:
                        raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

                    value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
                    lora.load_state_dict(value_dict)
                    non_attn_lora_layers.append((attn_processor, lora))
                    continue

                rank = value_dict["to_k_lora.down.weight"].shape[0]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
@@ -390,10 +427,16 @@ class UNet2DConditionLoadersMixin:
        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
        non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]

        # set layers
        self.set_attn_processor(attn_processors)

        # set ff layers
        for target_module, lora_layer in non_attn_lora_layers:
            if hasattr(target_module, "set_lora_layer"):
                target_module.set_lora_layer(lora_layer)

    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
@@ -840,7 +883,10 @@ class LoraLoaderMixin:
        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
        self.load_lora_into_text_encoder(
            state_dict,
            network_alpha=network_alpha,
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
        )

    @classmethod
@@ -1049,6 +1095,7 @@ class LoraLoaderMixin:
        text_encoder_lora_state_dict = {
            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
        }

        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {prefix}.")
@@ -1092,8 +1139,9 @@ class LoraLoaderMixin:
            rank = text_encoder_lora_state_dict[
                "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
            ].shape[1]
            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

            # set correct dtype & device
            text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ class LoraLoaderMixin:
                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                attn_module.out_proj = attn_module.out_proj.regular_linear_layer

        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
            if isinstance(mlp_module.fc1, PatchedLoraProjection):
                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer

    @classmethod
    def _modify_text_encoder(
        cls,
        text_encoder,
        lora_scale=1,
        network_alpha=None,
        rank=4,
        dtype=None,
        patch_mlp=False,
    ):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.
        """
@@ -1157,6 +1218,18 @@ class LoraLoaderMixin:
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

        if patch_mlp:
            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                mlp_module.fc1 = PatchedLoraProjection(
                    mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

                mlp_module.fc2 = PatchedLoraProjection(
                    mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

        return lora_parameters

    @classmethod
@@ -1261,9 +1334,12 @@ class LoraLoaderMixin:
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None
        unloaded_keys = []

        for key, value in state_dict.items():
            if "hada" in key or "skip" in key:
                unloaded_keys.append(key)
            elif "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_name_up = lora_name + ".lora_up.weight"
                lora_name_alpha = lora_name + ".alpha"

@@ -1284,12 +1360,21 @@ class LoraLoaderMixin:
                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
                    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
                    diffusers_name = diffusers_name.replace("proj.out", "proj_out")

                    if "transformer_blocks" in diffusers_name:
                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                            unet_state_dict[diffusers_name] = value
                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                        elif "ff" in diffusers_name:
                            unet_state_dict[diffusers_name] = value
                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
                        unet_state_dict[diffusers_name] = value
                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                elif lora_name.startswith("lora_te_"):
                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                    diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ class LoraLoaderMixin:
                    if "self_attn" in diffusers_name:
                        te_state_dict[diffusers_name] = value
                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                    elif "mlp" in diffusers_name:
                        # Be aware that this is the new diffusers convention and the rest of the code might
                        # not utilize it yet.
                        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
                        te_state_dict[diffusers_name] = value
                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

        logger.info("Kohya-style checkpoint detected.")
        if len(unloaded_keys) > 0:
            example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
            logger.warning(
                f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
            )

        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
@@ -1346,6 +1444,10 @@ class LoraLoaderMixin:
        [attention_proc_class] = unet_attention_classes
        self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

        for _, module in self.unet.named_modules():
            if hasattr(module, "set_lora_layer"):
                module.set_lora_layer(None)

        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
......
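To make the conversion above easier to follow, here is a small standalone sketch (not the library function itself) of the key-renaming idea applied to the UNet entries of a Kohya-ss checkpoint; the example key and helper name are illustrative only.

```python
# Standalone sketch of the Kohya-ss -> diffusers key renaming idea used above.
# The helper and example key are illustrative; the real loader also remaps
# attn1/attn2 to their ".processor" counterparts and prefixes "unet." / "text_encoder.".
def kohya_unet_key_to_diffusers(key: str) -> str:
    name = key.replace("lora_unet_", "").replace("_", ".")
    name = name.replace("down.blocks", "down_blocks")
    name = name.replace("transformer.blocks", "transformer_blocks")
    name = name.replace("to.q.lora", "to_q_lora")
    name = name.replace("to.k.lora", "to_k_lora")
    name = name.replace("to.v.lora", "to_v_lora")
    name = name.replace("to.out.0.lora", "to_out_lora")
    name = name.replace("proj.in", "proj_in")
    name = name.replace("proj.out", "proj_out")
    return name


print(kohya_unet_key_to_diffusers("lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight"))
# -> "down_blocks.0.attentions.0.proj_in.lora.down.weight"
```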
@@ -21,6 +21,7 @@ from ..utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear


@maybe_allow_in_graph

@@ -245,7 +246,7 @@ class FeedForward(nn.Module):
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
-       self.net.append(nn.Linear(inner_dim, dim_out))
+       self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

@@ -289,7 +290,7 @@ class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
-       self.proj = nn.Linear(dim_in, dim_out * 2)
+       self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
......
@@ -19,6 +19,7 @@ from torch import nn
from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .lora import LoRALinearLayer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -505,36 +506,6 @@ class AttnProcessor:
        return hidden_states


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRAAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism.
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch import nn

class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRAConv2dLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
        self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRACompatibleConv(nn.Conv2d):
    """
    A convolutional layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
        self.lora_layer = lora_layer

    def forward(self, x):
        if self.lora_layer is None:
            return super().forward(x)
        else:
            return super().forward(x) + self.lora_layer(x)

class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def forward(self, x):
        if self.lora_layer is None:
            return super().forward(x)
        else:
            return super().forward(x) + self.lora_layer(x)
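To illustrate how these classes fit together, here is a short self-contained sketch (not part of the diff) that attaches a `LoRALinearLayer` to a `LoRACompatibleLinear` via `set_lora_layer`:

```python
import torch

from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

# Sketch only: attach a LoRA layer to a LoRA-compatible linear layer.
base = LoRACompatibleLinear(32, 64)
lora = LoRALinearLayer(32, 64, rank=4, network_alpha=4)

x = torch.randn(2, 32)
out_without_lora = base(x)

base.set_lora_layer(lora)
out_with_lora = base(x)

# `lora.up` is zero-initialized, so the LoRA branch contributes nothing until trained.
assert torch.allclose(out_without_lora, out_with_lora)

# Detaching the LoRA layer (as `unload_lora_weights` does for every module that
# exposes `set_lora_layer`) restores the plain linear behaviour.
base.set_lora_layer(None)
```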
@@ -23,6 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin

@@ -138,7 +139,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            if use_linear_projection:
                self.proj_in = nn.Linear(in_channels, inner_dim)
            else:
-               self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+               self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

@@ -194,7 +195,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            if use_linear_projection:
                self.proj_out = nn.Linear(inner_dim, in_channels)
            else:
-               self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+               self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
......
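A small sketch of what this change enables: with `use_linear_projection=False`, `proj_in` and `proj_out` are now `LoRACompatibleConv` modules and accept a LoRA layer through `set_lora_layer`. The constructor arguments below are arbitrary, chosen only to build a tiny model for illustration.

```python
import torch

from diffusers.models.lora import LoRAConv2dLayer
from diffusers.models.transformer_2d import Transformer2DModel

# Sketch with assumed (arbitrary) configuration, not taken from the PR.
model = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=32,
    in_channels=32,
    use_linear_projection=False,
)

# Attach a rank-4 LoRA adapter to the input projection convolution.
model.proj_in.set_lora_layer(LoRAConv2dLayer(32, 32, rank=4))

sample = torch.randn(1, 32, 8, 8)
out = model(sample).sample
```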
@@ -22,6 +22,7 @@ from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils import torch_device

@@ -482,7 +483,7 @@ class Transformer2DModelTests(unittest.TestCase):
        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
        assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
-       assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+       assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear

        dim = 32
        inner_dim = 128

@@ -506,7 +507,7 @@ class Transformer2DModelTests(unittest.TestCase):
        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
        assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
-       assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+       assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear

        dim = 32
        inner_dim = 128
......
@@ -738,7 +738,7 @@ class LoraIntegrationTests(unittest.TestCase):
        images = images[0, -3:, -3:, -1].flatten()

-       expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
+       expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

        self.assertTrue(np.allclose(images, expected, atol=1e-4))

@@ -778,6 +778,7 @@ class LoraIntegrationTests(unittest.TestCase):
        lora_filename = "Colored_Icons_by_vizsumit.safetensors"
        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

        generator = torch.manual_seed(0)
        lora_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images

@@ -792,3 +793,50 @@ class LoraIntegrationTests(unittest.TestCase):
        self.assertFalse(np.allclose(initial_images, lora_images))
        self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))

    def test_load_unload_load_kohya_lora(self):
        # This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded
        # again without introducing any side-effects. Even though the test uses a Kohya-style
        # LoRA, the underlying adapter handling mechanism is format-agnostic.
        generator = torch.manual_seed(0)
        prompt = "masterpiece, best quality, mountain"
        num_inference_steps = 2

        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
            torch_device
        )
        initial_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        initial_images = initial_images[0, -3:, -3:, -1].flatten()

        lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
        lora_filename = "Colored_Icons_by_vizsumit.safetensors"

        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
        generator = torch.manual_seed(0)
        lora_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        lora_images = lora_images[0, -3:, -3:, -1].flatten()

        pipe.unload_lora_weights()
        generator = torch.manual_seed(0)
        unloaded_lora_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()

        self.assertFalse(np.allclose(initial_images, lora_images))
        self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))

        # make sure we can load a LoRA again after unloading and it doesn't have
        # any undesired effects.
        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
        generator = torch.manual_seed(0)
        lora_images_again = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()

        self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))