Unverified Commit 8e552bb4 authored by Takuma Mori, committed by GitHub

Support Kohya-ss style LoRA file format (in a limited capacity) (#3437)



* add _convert_kohya_lora_to_diffusers

* make style

* add scaffold

* match result: unet attention only

* fix monkey-patch for text_encoder

* with CLIPAttention

While the terrible images are no longer produced,
the results do not match those from the hook version.
This may be because the network_alpha value is not set.

* add to support network_alpha

* generate diff image

* fix monkey-patch for text_encoder

* add test_text_encoder_lora_monkey_patch()

* verify that it's okay to release the attn_procs

* fix closure version

* add comment

* Revert "fix monkey-patch for text_encoder"

This reverts commit bb9c61e6faecc1935c9c4319c77065837655d616.

* Fix to reuse utility functions

* make LoRAAttnProcessor targets to self_attn

* fix LoRAAttnProcessor target

* make style

* fix split key

* Update src/diffusers/loaders.py

* remove TEXT_ENCODER_TARGET_MODULES loop

* add print memory usage

* remove test_kohya_loras_scaffold.py

* add: doc on LoRA civitai

* remove print statement and refactor in the doc.

* fix state_dict test for kohya-ss style lora

* Apply suggestions from code review
Co-authored-by: Takuma Mori <takuma104@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 32ea2142
......@@ -272,4 +272,75 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
refer to the respective docstrings.
\ No newline at end of file
refer to the respective docstrings.
## Supporting A1111-themed LoRA checkpoints from Diffusers
To provide seamless interoperability with A1111 for our users, we support loading A1111-formatted
LoRA checkpoints with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
In this section, we explain how to load an A1111-formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
in Diffusers and perform inference with it.
First, download a checkpoint. We'll use
[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes.
```bash
wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors
```
Next, we initialize a [`~DiffusionPipeline`]:
```python
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
pipeline = StableDiffusionPipeline.from_pretrained(
"gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, use_karras_sigmas=True
)
```
We then load the checkpoint downloaded from CivitAI:
```python
pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
```
<Tip warning={true}>
If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed (`pip install safetensors`).
</Tip>
And then it's time to run inference:
```python
prompt = "masterpiece, best quality, 1girl, at dusk"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
"bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts")
images = pipeline(prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=768,
num_inference_steps=15,
num_images_per_prompt=4,
generator=torch.manual_seed(0)
).images
```
Below is a comparison between the LoRA and the non-LoRA results:
![lora_non_lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_non_lora_comparison.png)
If you have a similar checkpoint stored on the Hugging Face Hub, you can load it
directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
```python
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
\ No newline at end of file
......@@ -58,7 +58,7 @@ from diffusers.models.attention_processor import (
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
......@@ -861,9 +861,9 @@ def main(args):
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = DiffusionPipeline.from_pretrained(
......
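The change above switches the match from individual projection-layer names to whole `.self_attn` modules and reads the hidden size from `out_proj`. Below is a minimal sketch of what that loop selects on a CLIP text encoder; it assumes a diffusers version from around this PR (where `TEXT_ENCODER_ATTN_MODULE` and `LoRAAttnProcessor` are importable as shown) and uses a default `CLIPTextConfig` purely for illustration.
```python
# A minimal, self-contained sketch (assumptions noted above) of what the new
# matching rule selects on a CLIP text encoder.
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE

text_encoder = CLIPTextModel(CLIPTextConfig())  # default config: hidden_size=512, 12 layers

text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
    # Matches whole attention blocks such as "text_model.encoder.layers.0.self_attn"
    # instead of the individual q/k/v/out projection layers.
    if name.endswith(TEXT_ENCODER_ATTN_MODULE):
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_proj.out_features,  # 512 with the default config
            cross_attention_dim=None,
        )

print(len(text_lora_attn_procs))  # one LoRA attention processor per encoder layer
```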
......@@ -72,8 +72,8 @@ class AttnProcsLayers(torch.nn.Module):
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
# .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
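For the text encoder, parameter keys now split at the whole `.self_attn` suffix rather than at the individual projection names. The sketch below only illustrates that splitting rule (the actual state-dict hooks are not shown in this hunk); the example keys are hypothetical but follow the naming used elsewhere in this PR.
```python
# An illustrative sketch of the splitting rule only.
split_keys = [".processor", ".self_attn"]


def module_name_from_key(key: str) -> str:
    for split_key in split_keys:
        if split_key in key:
            # keep everything up to and including the split marker
            return key.split(split_key)[0] + split_key
    raise ValueError(f"no known split marker in {key!r}")


# UNet parameter key -> attention processor name
print(module_name_from_key(
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
))
# Text encoder parameter key -> self-attention module name
print(module_name_from_key(
    "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"
))
```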
......@@ -182,6 +182,9 @@ class UNet2DConditionLoadersMixin:
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alpha = kwargs.pop("network_alpha", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
......@@ -287,7 +290,10 @@ class UNet2DConditionLoadersMixin:
attn_processor_class = LoRAAttnProcessor
attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
elif is_custom_diffusion:
......@@ -774,6 +780,8 @@ class LoraLoaderMixin:
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
......@@ -898,6 +906,11 @@ class LoraLoaderMixin:
else:
state_dict = pretrained_model_name_or_path_or_dict
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
......@@ -909,7 +922,7 @@ class LoraLoaderMixin:
unet_lora_state_dict = {
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
self.unet.load_attn_procs(unet_lora_state_dict)
self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
......@@ -918,7 +931,9 @@ class LoraLoaderMixin:
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
text_encoder_lora_state_dict, network_alpha=network_alpha
)
self._modify_text_encoder(attn_procs_text_encoder)
# save lora attn procs of text encoder so that it can be easily retrieved
......@@ -954,14 +969,20 @@ class LoraLoaderMixin:
module = self.text_encoder.get_submodule(name)
# Construct a new function that performs the LoRA merging. We will monkey patch
# this forward pass.
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
attn_processor_name = ".".join(name.split(".")[:-1])
lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
old_forward = module.forward
def new_forward(x):
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)
return new_forward
# Monkey-patch.
module.forward = new_forward
module.forward = make_new_forward(old_forward, lora_layer)
def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
......@@ -1048,6 +1069,7 @@ class LoraLoaderMixin:
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
network_alpha = kwargs.pop("network_alpha", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
......@@ -1125,7 +1147,10 @@ class LoraLoaderMixin:
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processors[key] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
......@@ -1219,6 +1244,56 @@ class LoraLoaderMixin:
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
def _convert_kohya_lora_to_diffusers(self, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None
for key, value in state_dict.items():
if "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name_alpha in state_dict:
alpha = state_dict[lora_name_alpha].item()
if network_alpha is None:
network_alpha = alpha
elif network_alpha != alpha:
raise ValueError("Network alpha is not consistent")
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alpha
class FromCkptMixin:
"""This helper class allows to directly load .ckpt stable diffusion file_extension
......
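To make the conversion above concrete, here is the kind of key renaming `_convert_kohya_lora_to_diffusers` performs. The keys below are illustrative examples only; a real checkpoint contains many such entries plus scalar `.alpha` values, which are consumed as `network_alpha` rather than copied into the new state dict. The `lora_down` keys drive the conversion, and the matching `lora_up` weights are attached to the `.up.` counterparts.
```python
# Illustrative example of the key renaming performed by _convert_kohya_lora_to_diffusers.
kohya_to_diffusers = {
    # UNet attention projection (the "lora_down" key drives the conversion)
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight":
        "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight",
    # the matching "lora_up" weight is attached to the ".up." counterpart
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight":
        "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up.weight",
    # text encoder self-attention projection
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight":
        "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight",
}
```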
......@@ -508,7 +508,7 @@ class AttnProcessor:
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()
if rank > min(in_features, out_features):
......@@ -516,6 +516,10 @@ class LoRALinearLayer(nn.Module):
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
......@@ -527,6 +531,9 @@ class LoRALinearLayer(nn.Module):
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
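With `network_alpha` set, the LoRA delta produced by a `LoRALinearLayer` is scaled by `network_alpha / rank`, matching the kohya-ss convention. A minimal numerical sketch of that scaling, assuming the `LoRALinearLayer` import path from this version of diffusers:
```python
# The LoRA delta becomes (network_alpha / rank) * up(down(x)).
import torch

from diffusers.models.attention_processor import LoRALinearLayer

torch.manual_seed(0)
x = torch.randn(2, 32)

plain = LoRALinearLayer(32, 32, rank=4)
scaled = LoRALinearLayer(32, 32, rank=4, network_alpha=2.0)

# up.weight is zero-initialized, so give it non-trivial values, then copy the
# weights so both layers compute the same up(down(x)) before any scaling
torch.nn.init.normal_(plain.up.weight)
scaled.load_state_dict(plain.state_dict())

# with network_alpha=2.0 and rank=4, the scaled output is exactly 0.5x the plain one
assert torch.allclose(scaled(x), (2.0 / 4) * plain(x))
```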
......@@ -543,17 +550,17 @@ class LoRAAttnProcessor(nn.Module):
The dimension of the LoRA update matrices.
"""
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
......@@ -838,19 +845,19 @@ class LoRAAttnAddedKVProcessor(nn.Module):
The dimension of the LoRA update matrices.
"""
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
......@@ -1157,7 +1164,9 @@ class LoRAXFormersAttnProcessor(nn.Module):
operator.
"""
def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
def __init__(
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
):
super().__init__()
self.hidden_size = hidden_size
......@@ -1165,10 +1174,10 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.rank = rank
self.attention_op = attention_op
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
......
......@@ -30,6 +30,7 @@ from .constants import (
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
TEXT_ENCODER_ATTN_MODULE,
TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
......
......@@ -31,3 +31,4 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest
......@@ -30,7 +31,7 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
def create_unet_lora_layers(unet: nn.Module):
......@@ -50,15 +51,35 @@ def create_unet_lora_layers(unet: nn.Module):
return lora_attn_procs, unet_lora_layers
def create_text_encoder_lora_layers(text_encoder: nn.Module):
def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
return text_lora_attn_procs
def create_text_encoder_lora_layers(text_encoder: nn.Module):
text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
return text_encoder_lora_layers
def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
for _, attn_proc in text_lora_attn_procs.items():
# set up.weights
for layer_name, layer_module in attn_proc.named_modules():
if layer_name.endswith("_lora"):
weight = (
torch.randn_like(layer_module.up.weight)
if randn_weight
else torch.zeros_like(layer_module.up.weight)
)
layer_module.up.weight = torch.nn.Parameter(weight)
class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
......@@ -220,6 +241,64 @@ class LoraLoaderMixinTests(unittest.TestCase):
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
max_seq_length = 77
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
prepared_inputs = {}
prepared_inputs["input_ids"] = inputs
return prepared_inputs
def test_text_encoder_lora_monkey_patch(self):
pipeline_components, _ = self.get_dummy_components()
pipe = StableDiffusionPipeline(**pipeline_components)
dummy_tokens = self.get_dummy_tokens()
# inference without lora
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32)
# create lora_attn_procs with zeroed out up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=False)
# monkey patch
pipe._modify_text_encoder(text_attn_procs)
# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)
assert torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora_up_weight is all zero, so the lora outputs should be the same as the outputs without lora"
# create lora_attn_procs with randn up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=True)
# monkey patch
pipe._modify_text_encoder(text_attn_procs)
# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
del text_attn_procs
gc.collect()
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_with_lora.shape == (1, 77, 32)
assert not torch.allclose(
outputs_without_lora, outputs_with_lora
), "lora_up_weight is not zero, so the lora outputs should be different from the outputs without lora"
def create_lora_weight_file(self, tmpdirname):
_, lora_components = self.get_dummy_components()
LoraLoaderMixin.save_lora_weights(
......