Unverified commit 4a4cdd6b, authored by Sayak Paul and committed by GitHub

[Feat] Support SDXL Kohya-style LoRA (#4287)



* sdxl lora changes.

* better name replacement.

* better replacement.

* debugging

* debugging

* debugging

* debugging

* debugging

* remove print.

* print state dict keys.

* print

* distinguish better

* debuggable.

* fix: tests

* fix: arg from training script.

* access from class.

* run style

* debug

* save intermediate

* some simplifications for SDXL LoRA

* styling

* unet config is not needed in diffusers format.

* fix: dynamic SGM block mapping for SDXL kohya loras (#4322)

* Use lora compatible layers for linear proj_in/proj_out (#4323)

* improve condition for using the sgm_diffusers mapping

* informative comment.

* load compatible keys and embedding layer mapping.

* Get SDXL 1.0 example lora to load

* simplify

* specify ranks and hidden sizes.

* better handling of k rank and hidden size

* debug

* debug

* debug

* debug

* debug

* fix: alpha keys

* add check for handling LoRAAttnAddedKVProcessor

* sanity comment

* modifications for text encoder SDXL

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* up

* up

* up

* up

* up

* up

* unneeded comments.

* unneeded comments.

* kwargs for the other attention processors.

* kwargs for the other attention processors.

* debugging

* debugging

* debugging

* debugging

* improve

* debugging

* debugging

* more print

* Fix alphas

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up

* clean up.

* debugging

* fix: text

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Batuhan Taskaya <batuhan@python.org>
parent b7b6d613
@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
\ No newline at end of file
### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer
With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
Here are some example checkpoints we tried out:
* SDXL 0.9:
  * https://civitai.com/models/22279?modelVersionId=118556
  * https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
  * https://civitai.com/models/108448/daiton-sdxl-test
  * https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
* SDXL 1.0:
  * https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
Here is an example of how to perform inference with these checkpoints in `diffusers`:
```python
from diffusers import DiffusionPipeline
import torch
base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
generator = torch.manual_seed(2947883060)
num_inference_steps = 30
guidance_scale = 7
image = pipeline(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
generator=generator, guidance_scale=guidance_scale
).images[0]
image.save("Kamepan.png")
```
`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556.
Note that the inference UX is identical to what we presented in the sections above, and SDXL 1.0 LoRAs load the same way, as sketched below.
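For SDXL 1.0, the offset example LoRA listed above can be loaded straight from the Hub. Here is a minimal sketch (the prompt is only illustrative; the repository and filename come from the checkpoint list above):
```python
from diffusers import DiffusionPipeline
import torch

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")

# The offset LoRA lives in the same repository as the base model.
pipeline.load_lora_weights(base_model_id, weight_name="sd_xl_offset_example-lora_1.0.safetensors")

image = pipeline(
    "a photo of a cozy cabin in the woods at night, detailed, high contrast",
    num_inference_steps=30,
    guidance_scale=7,
).images[0]
image.save("offset_lora.png")
```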
Thanks to [@isidentical](https://github.com/isidentical) for helping us integrate this feature.
### Known limitations specific to the Kohya-style LoRAs
* SDXL LoRAs that target both text encoders currently lead to unexpected results. We're actively investigating the issue.
* When images don't look similar to the output of other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
\ No newline at end of file
@@ -925,10 +925,10 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
-lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
-lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
accelerator.register_save_state_pre_hook(save_model_hook)
...
@@ -825,13 +825,13 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
-lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
-lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
+lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
LoraLoaderMixin.load_lora_into_text_encoder(
-lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
+lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
...
@@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
"""
-def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
+q_rank = kwargs.pop("q_rank", None)
+q_hidden_size = kwargs.pop("q_hidden_size", None)
+q_rank = q_rank if q_rank is not None else rank
+q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+v_rank = kwargs.pop("v_rank", None)
+v_hidden_size = kwargs.pop("v_hidden_size", None)
+v_rank = v_rank if v_rank is not None else rank
+v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+out_rank = kwargs.pop("out_rank", None)
+out_hidden_size = kwargs.pop("out_hidden_size", None)
+out_rank = out_rank if out_rank is not None else rank
+out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
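To make the effect of the new keyword arguments concrete, here is a hypothetical construction of a `LoRAAttnProcessor` with per-projection ranks and hidden sizes (the sizes below are made up; when a Kohya-style SDXL LoRA is loaded, they are inferred from the checkpoint):
```python
from diffusers.models.attention_processor import LoRAAttnProcessor

# Hypothetical sizes for illustration only.
processor = LoRAAttnProcessor(
    hidden_size=1280,
    cross_attention_dim=2048,
    rank=4,              # default rank, still used for to_k
    q_rank=8,            # per-projection overrides introduced in this PR
    q_hidden_size=1280,
    v_rank=8,
    v_hidden_size=1280,
    out_rank=8,
    out_hidden_size=1280,
)
print(processor.to_q_lora.down.weight.shape)  # torch.Size([8, 1280])
print(processor.to_k_lora.down.weight.shape)  # torch.Size([4, 2048])
```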
@@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module):
"""
def __init__(
-self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+self,
+hidden_size,
+cross_attention_dim,
+rank=4,
+attention_op: Optional[Callable] = None,
+network_alpha=None,
+**kwargs,
):
super().__init__()
@@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.rank = rank
self.attention_op = attention_op
+q_rank = kwargs.pop("q_rank", None)
+q_hidden_size = kwargs.pop("q_hidden_size", None)
+q_rank = q_rank if q_rank is not None else rank
+q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+v_rank = kwargs.pop("v_rank", None)
+v_hidden_size = kwargs.pop("v_hidden_size", None)
+v_rank = v_rank if v_rank is not None else rank
+v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+out_rank = kwargs.pop("out_rank", None)
+out_hidden_size = kwargs.pop("out_hidden_size", None)
+out_rank = out_rank if out_rank is not None else rank
+out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
"""
-def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module):
self.cross_attention_dim = cross_attention_dim
self.rank = rank
+q_rank = kwargs.pop("q_rank", None)
+q_hidden_size = kwargs.pop("q_hidden_size", None)
+q_rank = q_rank if q_rank is not None else rank
+q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+v_rank = kwargs.pop("v_rank", None)
+v_hidden_size = kwargs.pop("v_hidden_size", None)
+v_rank = v_rank if v_rank is not None else rank
+v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+out_rank = kwargs.pop("out_rank", None)
+out_hidden_size = kwargs.pop("out_hidden_size", None)
+out_rank = out_rank if out_rank is not None else rank
+out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
...
@@ -49,14 +49,19 @@ class LoRALinearLayer(nn.Module):
class LoRAConv2dLayer(nn.Module):
-def __init__(self, in_features, out_features, rank=4, network_alpha=None):
+def __init__(
+self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
-self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
+self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
-self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
+# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
+# see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
...
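For illustration, here is a sketch of how the widened constructor could shadow a 3x3 convolution (channel counts are made up; the up projection stays 1x1 per the kohya_ss convention referenced in the comment above):
```python
from diffusers.models.lora import LoRAConv2dLayer

# Hypothetical 3x3 LoRA conv with 320 input/output channels.
lora_conv = LoRAConv2dLayer(
    in_features=320,
    out_features=320,
    rank=4,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=1,
    network_alpha=4,
)
print(lora_conv.down.kernel_size)  # (3, 3)
print(lora_conv.up.kernel_size)    # (1, 1)
```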
@@ -23,6 +23,7 @@ import torch.nn.functional as F
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
+from .lora import LoRACompatibleConv, LoRACompatibleLinear
class Upsample1D(nn.Module):
@@ -126,7 +127,7 @@ class Upsample2D(nn.Module):
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
-conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
@@ -196,7 +197,7 @@ class Downsample2D(nn.Module):
self.name = name
if use_conv:
-conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module):
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
-self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
-self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
-self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
@@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
-self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
@@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
-self.conv_shortcut = torch.nn.Conv2d(
+self.conv_shortcut = LoRACompatibleConv(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
...
@@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
-from .lora import LoRACompatibleConv
+from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
@@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
-self.proj_in = nn.Linear(in_channels, inner_dim)
+self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
@@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
-self.proj_out = nn.Linear(inner_dim, in_channels)
+self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
...
@@ -88,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
-- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+- *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
-- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+- *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -866,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
-state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+# We could have accessed the unet config from `lora_state_dict()` too. We pass
+# it here explicitly to be able to tell that it's coming from an SDXL
+# pipeline.
+state_dict, network_alphas = self.lora_state_dict(
+pretrained_model_name_or_path_or_dict,
+unet_config=self.unet.config,
+**kwargs,
+)
+self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
-network_alpha=network_alpha,
+network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
@@ -883,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
-network_alpha=network_alpha,
+network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
...
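For reference, the override boils down to the lower-level loader calls shown in the hunk above. Here is a minimal sketch of driving them by hand with the SDXL 1.0 offset LoRA from the docs (the key filter at the end is only for inspection):
```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Fetch the (possibly Kohya-format) state dict, passing the UNet config so the
# SDXL ("SGM"-style) block names can be mapped to diffusers names.
state_dict, network_alphas = pipe.lora_state_dict(
    "stabilityai/stable-diffusion-xl-base-1.0",
    weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    unet_config=pipe.unet.config,
)
pipe.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=pipe.unet)

# Any text-encoder LoRA layers would be routed to the right encoder by prefix.
te_keys = [k for k in state_dict if "text_encoder." in k]
print(f"{len(te_keys)} text encoder LoRA keys found")
```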
@@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
images = images[0, -3:, -3:, -1].flatten()
-expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
+expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
...