"Doc/Tutorial_Python.md" did not exist on "c80d4455a26a963ab0cc568b424dddbcd73a71ea"
Unverified commit ff8f5808 authored by Batuhan Taskaya, committed by GitHub

Load Kohya-ss style LoRAs with auxiliary states (#4147)



* Support loading the Kohya-ss style LoRA file format (without restrictions)
Co-Authored-By: Takuma Mori <takuma104@gmail.com>
Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

* tmp: add sdxl to mlp_modules

---------
Co-authored-by: Takuma Mori <takuma104@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 161449d5
@@ -14,8 +14,7 @@ specific language governing permissions and limitations under the License.
<Tip warning={true}>
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also
support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage.
This is an experimental feature, and its APIs may change in the future.
</Tip>
@@ -286,6 +285,8 @@ You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pip
## Supporting A1111 themed LoRA checkpoints from Diffusers
This support was made possible because of our amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical).
To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
......
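A minimal usage sketch of that flow (the repo id and weight name below come from this commit's integration tests; the CUDA device placement is an assumption):

```python
from diffusers import StableDiffusionPipeline

# Any Stable Diffusion 1.5-style base checkpoint works here.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")  # assumes a CUDA device is available

# `load_lora_weights` accepts A1111/Kohya-formatted `.safetensors` files directly.
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)

image = pipe("masterpiece, best quality, mountain", num_inference_steps=25).images[0]
```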
@@ -25,6 +25,7 @@ import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch import nn
from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
@@ -56,6 +57,7 @@ UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TOTAL_EXAMPLE_KEYS = 5
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules
def text_encoder_mlp_modules(text_encoder):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
mlp_mod = layer.mlp
name = f"text_model.encoder.layers.{i}.mlp"
mlp_modules.append((name, mlp_mod))
else:
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
return mlp_modules
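A hedged usage sketch of the helper above, assuming it is imported from `diffusers.loaders`; the CLIP checkpoint id is illustrative:

```python
from transformers import CLIPTextModel
from diffusers.loaders import text_encoder_mlp_modules

# The text encoder used by Stable Diffusion 1.x pipelines.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Yields one ("text_model.encoder.layers.{i}.mlp", CLIPMLP) pair per encoder layer.
for name, mlp in text_encoder_mlp_modules(text_encoder):
    print(name, type(mlp).__name__)
```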
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
@@ -304,6 +320,7 @@ class UNet2DConditionLoadersMixin:
# fill attn processors
attn_processors = {}
non_attn_lora_layers = []
is_lora = all("lora" in k for k in state_dict.keys())
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -327,13 +344,33 @@
lora_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in lora_grouped_dict.items():
rank = value_dict["to_k_lora.down.weight"].shape[0]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
if "lora.down.weight" in value_dict:
rank = value_dict["lora.down.weight"].shape[0]
hidden_size = value_dict["lora.up.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features, attn_processor.out_features, rank, network_alpha
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
non_attn_lora_layers.append((attn_processor, lora))
continue
rank = value_dict["to_k_lora.down.weight"].shape[0]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
@@ -390,10 +427,16 @@ class UNet2DConditionLoadersMixin:
# set correct dtype & device
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]
# set layers
self.set_attn_processor(attn_processors)
# set ff layers
for target_module, lora_layer in non_attn_lora_layers:
if hasattr(target_module, "set_lora_layer"):
target_module.set_lora_layer(lora_layer)
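The `set_lora_layer` hook keeps the original module in place and merely attaches the adapter. A minimal sketch of that wiring, with arbitrarily chosen dimensions:

```python
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

# Behaves exactly like nn.Linear until a LoRA layer is attached.
linear = LoRACompatibleLinear(320, 320)
linear.set_lora_layer(LoRALinearLayer(320, 320, rank=4, network_alpha=4))
# forward() now returns the base projection plus the LoRA residual.
```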
def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
@@ -840,7 +883,10 @@ class LoraLoaderMixin:
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
self.load_lora_into_text_encoder(
state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
state_dict,
network_alpha=network_alpha,
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
)
@classmethod
@@ -1049,6 +1095,7 @@ class LoraLoaderMixin:
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
@@ -1092,8 +1139,9 @@ class LoraLoaderMixin:
rank = text_encoder_lora_state_dict[
"text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
# set correct dtype & device
text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ class LoraLoaderMixin:
attn_module.v_proj = attn_module.v_proj.regular_linear_layer
attn_module.out_proj = attn_module.out_proj.regular_linear_layer
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
@classmethod
def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
def _modify_text_encoder(
cls,
text_encoder,
lora_scale=1,
network_alpha=None,
rank=4,
dtype=None,
patch_mlp=False,
):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
@@ -1157,6 +1218,18 @@ class LoraLoaderMixin:
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
if patch_mlp:
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
mlp_module.fc1 = PatchedLoraProjection(
mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
mlp_module.fc2 = PatchedLoraProjection(
mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
return lora_parameters
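For readers unfamiliar with `PatchedLoraProjection` (defined elsewhere in `loaders.py`), a simplified, hypothetical stand-in conveys the idea: the wrapped `fc1`/`fc2` linear layer keeps its weights, and its output gains a scaled low-rank correction.

```python
import torch.nn as nn
from diffusers.models.lora import LoRALinearLayer

class TinyPatchedProjection(nn.Module):
    """Hypothetical, simplified analogue of PatchedLoraProjection (no dtype handling)."""

    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_scale = lora_scale
        self.lora_linear_layer = LoRALinearLayer(
            regular_linear_layer.in_features,
            regular_linear_layer.out_features,
            rank=rank,
            network_alpha=network_alpha,
        )

    def forward(self, x):
        # Base projection plus the scaled low-rank correction.
        return self.regular_linear_layer(x) + self.lora_scale * self.lora_linear_layer(x)
```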
@classmethod
@@ -1261,9 +1334,12 @@ class LoraLoaderMixin:
unet_state_dict = {}
te_state_dict = {}
network_alpha = None
unloaded_keys = []
for key, value in state_dict.items():
if "lora_down" in key:
if "hada" in key or "skip" in key:
unloaded_keys.append(key)
elif "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
@@ -1284,12 +1360,21 @@ class LoraLoaderMixin:
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ class LoraLoaderMixin:
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
logger.info("Kohya-style checkpoint detected.")
if len(unloaded_keys) > 0:
example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
logger.warning(
f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
)
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
@@ -1346,6 +1444,10 @@ class LoraLoaderMixin:
[attention_proc_class] = unet_attention_classes
self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
for _, module in self.unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
......
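End to end, the load/unload cycle that the changes above enable looks roughly like this, reusing the pipeline and identifiers from the earlier sketch:

```python
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)
lora_image = pipe("masterpiece, best quality, mountain").images[0]

# Resets the attention processors, clears every `lora_layer` slot, and removes
# the text-encoder monkey patch, restoring the original pipeline behaviour.
pipe.unload_lora_weights()
clean_image = pipe("masterpiece, best quality, mountain").images[0]
```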
@@ -21,6 +21,7 @@ from ..utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear
@maybe_allow_in_graph
@@ -245,7 +246,7 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
@@ -289,7 +290,7 @@ class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
def gelu(self, gate):
if gate.device.type != "mps":
......
@@ -19,6 +19,7 @@ from torch import nn
from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .lora import LoRALinearLayer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -505,36 +506,6 @@ class AttnProcessor:
return hidden_states
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
class LoRAAttnProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism.
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch import nn
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
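In math form, the forward pass above computes the standard LoRA residual, with $A$ the `down` projection, $B$ the `up` projection, $r$ the rank, and $\alpha$ the optional Kohya-style `network_alpha`:

```latex
\mathrm{LoRA}(x) = \begin{cases}
  \frac{\alpha}{r}\, B A x & \text{if network\_alpha is set} \\
  B A x & \text{otherwise}
\end{cases}
\qquad A \in \mathbb{R}^{r \times d_{\text{in}}},\; B \in \mathbb{R}^{d_{\text{out}} \times r}
```

Because `up` is initialized to zero, a freshly constructed layer contributes nothing until trained weights are loaded into it.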
class LoRAConv2dLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
class LoRACompatibleConv(nn.Conv2d):
"""
A convolutional layer that can be used with LoRA.
"""
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer
def forward(self, x):
if self.lora_layer is None:
return super().forward(x)
else:
return super().forward(x) + self.lora_layer(x)
class LoRACompatibleLinear(nn.Linear):
"""
A Linear layer that can be used with LoRA.
"""
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer
def forward(self, x):
if self.lora_layer is None:
return super().forward(x)
else:
return super().forward(x) + self.lora_layer(x)
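A small sanity sketch of the residual behaviour, assuming the two classes above are importable; since `up` starts at zero, attaching a fresh LoRA layer leaves the output unchanged:

```python
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

layer = LoRACompatibleLinear(8, 8)
x = torch.randn(2, 8)
base = layer(x)  # lora_layer is None -> plain nn.Linear behaviour

layer.set_lora_layer(LoRALinearLayer(8, 8, rank=4))
patched = layer(x)  # base output + LoRA residual (zero at init)

assert torch.allclose(base, patched)
```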
@@ -23,6 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv
from .modeling_utils import ModelMixin
@@ -138,7 +139,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -194,7 +195,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
......
@@ -22,6 +22,7 @@ from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils import torch_device
@@ -482,7 +483,7 @@ class Transformer2DModelTests(unittest.TestCase):
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
dim = 32
inner_dim = 128
@@ -506,7 +507,7 @@ class Transformer2DModelTests(unittest.TestCase):
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
dim = 32
inner_dim = 128
......
@@ -738,7 +738,7 @@ class LoraIntegrationTests(unittest.TestCase):
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
@@ -778,6 +778,7 @@ class LoraIntegrationTests(unittest.TestCase):
lora_filename = "Colored_Icons_by_vizsumit.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
generator = torch.manual_seed(0)
lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
@@ -792,3 +793,50 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
def test_load_unload_load_kohya_lora(self):
# This test ensures that a Kohya-style LoRA can be safely unloaded and then loaded
# without introducing any side-effects. Even though the test uses a Kohya-style
# LoRA, the underlying adapter handling mechanism is format-agnostic.
generator = torch.manual_seed(0)
prompt = "masterpiece, best quality, mountain"
num_inference_steps = 2
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
torch_device
)
initial_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
initial_images = initial_images[0, -3:, -3:, -1].flatten()
lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
lora_filename = "Colored_Icons_by_vizsumit.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
generator = torch.manual_seed(0)
lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
lora_images = lora_images[0, -3:, -3:, -1].flatten()
pipe.unload_lora_weights()
generator = torch.manual_seed(0)
unloaded_lora_images = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
# make sure we can load a LoRA again after unloading and they don't have
# any undesired effects.
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
generator = torch.manual_seed(0)
lora_images_again = pipe(
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))