Unverified commit c2a28c34, authored by Will Berman and committed by GitHub

Refactor LoRA (#3778)



* refactor to support patching LoRA into T5

instantiate the lora linear layer on the same device as the regular linear layer

get lora rank from state dict

tests

fmt

can create lora layer in float32 even when rest of model is float16

fix loading model hook

remove load_lora_weights_ and T5 dispatching

remove Unet#attn_processors_state_dict

docstrings

* text encoder monkeypatch class method

* fix test

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 78922ed7
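At a glance, the refactor drops the `AttnProcsLayers`-based plumbing in favor of plain parameter lists and state-dict helpers. A minimal sketch of the resulting workflow, assuming a standard Stable Diffusion checkpoint; the model id, output directory, and the elided training loop are illustrative placeholders, not part of this commit:

```python
# Illustrative sketch only; model id and paths are placeholders.
import torch
from diffusers import DiffusionPipeline
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Patch LoRA layers into the text encoder; the classmethod returns the new
# trainable parameters and can keep them in float32 even when the rest of the
# model is loaded in float16.
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(pipe.text_encoder, dtype=torch.float32)

# ... train text_lora_parameters (and any UNet LoRA attention processors) ...

# Serialize the LoRA weights with the new state-dict helper.
LoraLoaderMixin.save_lora_weights(
    save_directory="./lora-out",
    text_encoder_lora_layers=text_encoder_lora_state_dict(pipe.text_encoder),
)

# Later, load everything back in a single call.
pipe.load_lora_weights("./lora-out")
```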
@@ -23,6 +23,7 @@ import os
import shutil
import warnings
from pathlib import Path
+from typing import Dict

import numpy as np
import torch

@@ -50,7 +51,10 @@ from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import (
+    LoraLoaderMixin,
+    text_encoder_lora_state_dict,
+)
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@ from diffusers.models.attention_processor import (
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
    return prompt_embeds
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
r"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors
attn_processors_state_dict = {}
for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
return attn_processors_state_dict
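The flat dictionary this helper returns is what the script later hands to `LoraLoaderMixin.save_lora_weights` as `unet_lora_layers` (see the end of this file's diff). A hedged usage sketch, with the output directory as a placeholder:

```python
# Illustrative only: collect the UNet LoRA parameters into a flat state dict
# and save them together with the text encoder LoRA parameters.
unet_lora_layers = unet_attn_processors_state_dict(unet)
LoraLoaderMixin.save_lora_weights(
    save_directory="./dreambooth-lora",
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_state_dict(text_encoder),
)
```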
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

@@ -833,6 +853,7 @@ def main(args):
    # Set correct lora layers
    unet_lora_attn_procs = {}
+    unet_lora_parameters = []
    for name, attn_processor in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
@@ -850,35 +871,18 @@ def main(args):
        lora_attn_processor_class = (
            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank,
-        )
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())

    unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)
    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
-    text_encoder_lora_layers = None
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
-        text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features,
-                    cross_attention_dim=None,
-                    rank=args.rank,
-                )
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
@@ -887,23 +891,13 @@ def main(args):
        unet_lora_layers_to_save = None
        text_encoder_lora_layers_to_save = None

-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

        for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()
@@ -915,27 +909,24 @@ def main(args):
        )
    def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
-
-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
-
-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+        unet_ = None
+        text_encoder_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+        )
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

@@ -965,9 +956,9 @@ def main(args):
    # Optimizer creation
    params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
        if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
    )
    optimizer = optimizer_class(
        params_to_optimize,

@@ -1056,12 +1047,12 @@ def main(args):
    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.

@@ -1210,9 +1201,9 @@ def main(args):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                        if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()

@@ -1301,14 +1292,16 @@ def main(args):
                pipeline_args = {"prompt": args.validation_prompt}

                if args.validation_images is None:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, generator=generator).images[0]
+                            images.append(image)
                else:
                    images = []
                    for image in args.validation_images:
                        image = Image.open(image)
-                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                        images.append(image)

@@ -1332,12 +1325,16 @@ def main(args):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet_attn_processors_state_dict(unet)

-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+        else:
+            text_encoder_lora_layers = None

        LoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
...
@@ -20,6 +20,7 @@ from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
+from torch import nn
from .models.attention_processor import (
    AttnAddedKVProcessor,
@@ -29,6 +30,7 @@ from .models.attention_processor import (
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
+    LoRALinearLayer,
    LoRAXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    XFormersAttnProcessor,
@@ -36,7 +38,6 @@ from .models.attention_processor import (
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
-    TEXT_ENCODER_ATTN_MODULE,
    _get_model_file,
    deprecate,
    is_safetensors_available,
@@ -49,7 +50,7 @@ if is_safetensors_available():
    import safetensors

if is_transformers_available():
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer

logger = logging.get_logger(__name__)

@@ -67,6 +68,64 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
super().__init__()
self.regular_linear_layer = regular_linear_layer
device = self.regular_linear_layer.weight.device
if dtype is None:
dtype = self.regular_linear_layer.weight.dtype
self.lora_linear_layer = LoRALinearLayer(
self.regular_linear_layer.in_features,
self.regular_linear_layer.out_features,
network_alpha=network_alpha,
device=device,
dtype=dtype,
rank=rank,
)
self.lora_scale = lora_scale
def forward(self, input):
return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
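In other words, the patched projection leaves the wrapped linear layer untouched and adds a scaled low-rank correction on top. A small illustrative check; the sizes are arbitrary, and it assumes the LoRA up-projection is zero-initialized as in `LoRALinearLayer`:

```python
# Illustrative: wrap a plain linear layer and verify the LoRA path is additive.
import torch
from torch import nn
from diffusers.loaders import PatchedLoraProjection

linear = nn.Linear(32, 32)
patched = PatchedLoraProjection(linear, lora_scale=1.0, rank=4)

x = torch.randn(1, 77, 32)
# With a zero-initialized up-projection the LoRA branch contributes nothing,
# so the patched output matches the plain linear layer exactly.
assert torch.allclose(patched(x), linear(x))
```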
def text_encoder_attn_modules(text_encoder):
attn_modules = []
if isinstance(text_encoder, CLIPTextModel):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
else:
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
return attn_modules
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()

@@ -744,9 +803,48 @@ class LoraLoaderMixin:
    unet_name = UNET_NAME

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
`self.unet`.
See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
into `self.text_encoder`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs:
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
self.load_lora_into_text_encoder(
state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
)
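A hedged usage sketch for a pipeline `pipe`; the repository id and file name are borrowed from the integration tests further down and stand in for any LoRA checkpoint:

```python
# Illustrative: load a LoRA checkpoint, optionally pointing at a specific
# (e.g. A1111/kohya-ss formatted) safetensors file by name.
pipe.load_lora_weights(
    "hf-internal-testing/civitai-light-shadow-lora",
    weight_name="light_and_shadow.safetensors",
)
```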
@classmethod
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r""" r"""
Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and Return state dict for lora weights
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters: Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -801,9 +899,6 @@ class LoraLoaderMixin:
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

-        # set lora scale to a reasonable default
-        self._lora_scale = 1.0
-
        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"

@@ -840,7 +935,7 @@ class LoraLoaderMixin:
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except IOError as e:
+                except (IOError, safetensors.SafetensorError) as e:
                    if not allow_pickle:
                        raise e
                    # try loading non-safetensors weights
@@ -866,286 +961,185 @@ class LoraLoaderMixin:
        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
        network_alpha = None
        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)

+        return state_dict, network_alpha
@classmethod
def load_lora_into_unet(cls, state_dict, network_alpha, unet):
"""
This will load the LoRA layers specified in `state_dict` into `unet`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alpha (`float`):
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
"""
        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
-        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
+        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
-            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
-            logger.info(f"Loading {self.unet_name}.")
+            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+            logger.info(f"Loading {cls.unet_name}.")
            unet_lora_state_dict = {
-                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+                k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
            }
-            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
-
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {self.text_encoder_name}.")
-                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
-                    text_encoder_lora_state_dict, network_alpha=network_alpha
-                )
-                self._modify_text_encoder(attn_procs_text_encoder)
-
-                # save lora attn procs of text encoder so that it can be easily retrieved
-                self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+            unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

        # Otherwise, we're dealing with the old format. This means the `state_dict` should only
        # contain the module names of the `unet` as its keys WITHOUT any prefix.
        elif not all(
-            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+            key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
        ):
-            self.unet.load_attn_procs(state_dict)
+            unet.load_attn_procs(state_dict)
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
            warnings.warn(warn_message)
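The classmethod can also be driven directly from a raw state dict on standalone models, mirroring the `load_model_hook` in the training script above; the directory name is a placeholder:

```python
# Illustrative: split loading across an already-instantiated unet and text_encoder.
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict("./lora-out")
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet)
LoraLoaderMixin.load_lora_into_text_encoder(
    lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder
)
```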
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    @property
-    def text_encoder_lora_attn_procs(self):
-        if hasattr(self, "_text_encoder_lora_attn_procs"):
-            return self._text_encoder_lora_attn_procs
-        return
-
-    def _remove_text_encoder_monkey_patch(self):
-        # Loop over the CLIPAttention module of text_encoder
-        for name, attn_module in self.text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                # Loop over the LoRA layers
-                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                    # Retrieve the q/k/v/out projection of CLIPAttention
-                    module = attn_module.get_submodule(text_encoder_attr)
-                    if hasattr(module, "old_forward"):
-                        # restore original `forward` to remove monkey-patch
-                        module.forward = module.old_forward
-                        delattr(module, "old_forward")
-
-    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
-        r"""
-        Monkey-patches the forward passes of attention modules of the text encoder.
-
-        Parameters:
-            attn_processors: Dict[str, `LoRAAttnProcessor`]:
-                A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
-        """
-        # First, remove any monkey-patch that might have been applied before
-        self._remove_text_encoder_monkey_patch()
-
-        # Loop over the CLIPAttention module of text_encoder
-        for name, attn_module in self.text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                # Loop over the LoRA layers
-                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
-                    module = attn_module.get_submodule(text_encoder_attr)
-                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
-
-                    # save old_forward to module that can be used to remove monkey-patch
-                    old_forward = module.old_forward = module.forward
-
-                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                    def make_new_forward(old_forward, lora_layer):
-                        def new_forward(x):
-                            result = old_forward(x) + self.lora_scale * lora_layer(x)
-                            return result
-
-                        return new_forward
-
-                    # Monkey-patch.
-                    module.forward = make_new_forward(old_forward, lora_layer)
-
-    @property
-    def _lora_attn_processor_attr_to_text_encoder_attr(self):
-        return {
-            "to_q_lora": "q_proj",
-            "to_k_lora": "k_proj",
-            "to_v_lora": "v_proj",
-            "to_out_lora": "out_proj",
-        }
+    @classmethod
+    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
+        """
+        This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The key should be prefixed with an
+                additional `text_encoder` to distinguish between unet lora layers.
+            network_alpha (`float`):
+                See `LoRALinearLayer` for more details.
+            text_encoder (`CLIPTextModel`):
+                The text encoder model to load the LoRA layers into.
+            lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added with the output of the regular
+                lora layer.
+        """
+        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+        # their prefixes.
+        keys = list(state_dict.keys())
+        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
+            text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
+            text_encoder_lora_state_dict = {
+                k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+            }
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {cls.text_encoder_name}.")
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
#
# Previously, the old LoRA layers were stored on the state dict at the
# same level as the attention block i.e.
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
#
# This is no actual module at that point, they were monkey patched on to the
# existing module. We want to be able to load them via their actual state dict.
# They're in `PatchedLoraProjection.lora_linear_layer` now.
for name, _ in text_encoder_attn_modules(text_encoder):
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
rank = text_encoder_lora_state_dict[
"text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
].shape[1]
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+            # set correct dtype & device
+            text_encoder_lora_state_dict = {
+                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
+                for k, v in text_encoder_lora_state_dict.items()
+            }
-    def _load_text_encoder_attn_procs(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
-    ):
-        r"""
-        Load pretrained attention processor layers for
-        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
<Tip warning={true}>
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
`./my_model_directory/`.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
Returns:
`Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
[`LoRAAttnProcessor`].
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
network_alpha = kwargs.pop("network_alpha", None)
+            load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
+            if len(load_state_dict_results.unexpected_keys) != 0:
+                raise ValueError(
+                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+                )
+
+    @property
+    def lora_scale(self) -> float:
+        # property function that returns the lora scale which can be set at run time by the pipeline.
+        # if _lora_scale has not been set, return 1
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+
+    @classmethod
+    def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
+        for _, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
+                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
+                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
+                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
-            allow_pickle = True
-
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
-
-        model_file = None
-        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
-                    model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except IOError as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
-        # fill attn processors
-        attn_processors = {}
-
-        is_lora = all("lora" in k for k in state_dict.keys())
-
-        if is_lora:
-            lora_grouped_dict = defaultdict(dict)
-            for key, value in state_dict.items():
-                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-                lora_grouped_dict[attn_processor_key][sub_key] = value
-
-            for key, value_dict in lora_grouped_dict.items():
-                rank = value_dict["to_k_lora.down.weight"].shape[0]
-                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
-                attn_processor_class = (
-                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-                )
-                attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    rank=rank,
-                    network_alpha=network_alpha,
-                )
-                attn_processors[key].load_state_dict(value_dict)
-        else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
-
-        # set correct dtype & device
-        attn_processors = {
-            k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
-        }
-        return attn_processors
+    @classmethod
+    def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+        r"""
+        Monkey-patches the forward passes of attention modules of the text encoder.
+        """
+        # First, remove any monkey-patch that might have been applied before
+        cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
+
+        lora_parameters = []
+
+        for _, attn_module in text_encoder_attn_modules(text_encoder):
+            attn_module.q_proj = PatchedLoraProjection(
+                attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+            )
+            lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
+
+            attn_module.k_proj = PatchedLoraProjection(
+                attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+            )
+            lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
+
+            attn_module.v_proj = PatchedLoraProjection(
+                attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+            )
+            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
+
+            attn_module.out_proj = PatchedLoraProjection(
+                attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+            )
+            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
+
+        return lora_parameters
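As exercised in the updated tests below, the classmethod can be called from a pipeline instance and returns the freshly created LoRA parameters. A short hedged sketch for a pipeline `pipe`:

```python
# Illustrative: monkey-patch the pipeline's text encoder and zero the new weights.
import torch

params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
with torch.no_grad():
    for parameter in params:
        torch.zero_(parameter)
```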
    @classmethod
    def save_lora_weights(

@@ -1225,7 +1219,8 @@ class LoraLoaderMixin:
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

-    def _convert_kohya_lora_to_diffusers(self, state_dict):
+    @classmethod
+    def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None
...
@@ -506,14 +506,14 @@ class AttnProcessor:

class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
+    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
+        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
...
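The new `device`/`dtype` arguments are what let `PatchedLoraProjection` build its LoRA layer on the same device as the wrapped linear layer while forcing a different precision. A hedged example (assuming a CUDA device is available; sizes are arbitrary):

```python
# Illustrative: create the low-rank pair directly on the GPU in float32,
# independent of the precision of the surrounding model.
import torch
from diffusers.models.attention_processor import LoRALinearLayer

lora = LoRALinearLayer(768, 768, rank=4, device="cuda", dtype=torch.float32)
```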
@@ -30,7 +30,6 @@ from .constants import (
    ONNX_EXTERNAL_WEIGHTS_NAME,
    ONNX_WEIGHTS_NAME,
    SAFETENSORS_WEIGHTS_NAME,
-    TEXT_ENCODER_ATTN_MODULE,
    WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
...
@@ -30,4 +30,3 @@ DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_ATTN_MODULE = ".self_attn"
@@ -12,18 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
import os
import tempfile
import unittest

+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
+from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor,
@@ -33,7 +34,8 @@ from diffusers.models.attention_processor import (
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
+from diffusers.utils import floats_tensor, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow
def create_unet_lora_layers(unet: nn.Module):

@@ -63,11 +65,15 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
    lora_attn_processor_class = (
        LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
    )
-    for name, module in text_encoder.named_modules():
-        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-            text_lora_attn_procs[name] = lora_attn_processor_class(
-                hidden_size=module.out_proj.out_features, cross_attention_dim=None
-            )
+    for name, module in text_encoder_attn_modules(text_encoder):
+        if isinstance(module.out_proj, nn.Linear):
+            out_features = module.out_proj.out_features
+        elif isinstance(module.out_proj, PatchedLoraProjection):
+            out_features = module.out_proj.regular_linear_layer.out_features
+        else:
+            assert False, module.out_proj.__class__
+
+        text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None)
    return text_lora_attn_procs
@@ -77,17 +83,13 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
    return text_encoder_lora_layers

-def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
-    for _, attn_proc in text_lora_attn_procs.items():
-        # set up.weights
-        for layer_name, layer_module in attn_proc.named_modules():
-            if layer_name.endswith("_lora"):
-                weight = (
-                    torch.randn_like(layer_module.up.weight)
-                    if randn_weight
-                    else torch.zeros_like(layer_module.up.weight)
-                )
-                layer_module.up.weight = torch.nn.Parameter(weight)
+def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
+    with torch.no_grad():
+        for parameter in text_lora_attn_parameters:
+            if randn_weight:
+                parameter[:] = torch.randn_like(parameter)
+            else:
+                torch.zero_(parameter)


class LoraLoaderMixinTests(unittest.TestCase):
@@ -281,16 +283,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
        assert outputs_without_lora.shape == (1, 77, 32)

-        # create lora_attn_procs with zeroed out up.weights
-        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
-        set_lora_up_weights(text_attn_procs, randn_weight=False)
-
        # monkey patch
-        pipe._modify_text_encoder(text_attn_procs)
-
-        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
-        del text_attn_procs
-        gc.collect()
+        params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
+        set_lora_weights(params, randn_weight=False)

        # inference with lora
        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -301,15 +297,12 @@ class LoraLoaderMixinTests(unittest.TestCase):
        ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"

        # create lora_attn_procs with randn up.weights
-        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
-        set_lora_up_weights(text_attn_procs, randn_weight=True)
+        create_text_encoder_lora_attn_procs(pipe.text_encoder)

        # monkey patch
-        pipe._modify_text_encoder(text_attn_procs)
-
-        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
-        del text_attn_procs
-        gc.collect()
+        params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
+        set_lora_weights(params, randn_weight=True)

        # inference with lora
        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -329,16 +322,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
        assert outputs_without_lora.shape == (1, 77, 32)

-        # create lora_attn_procs with randn up.weights
-        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
-        set_lora_up_weights(text_attn_procs, randn_weight=True)
-
        # monkey patch
-        pipe._modify_text_encoder(text_attn_procs)
-
-        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
-        del text_attn_procs
-        gc.collect()
+        params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
+        set_lora_weights(params, randn_weight=True)

        # inference with lora
        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]

@@ -467,3 +454,86 @@ class LoraLoaderMixinTests(unittest.TestCase):
        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe(
"A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_dreambooth_text_encoder_new_format(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/lora-trained"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_a1111(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to(
torch_device
)
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_vanilla_funetuning(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
self.assertTrue(np.allclose(images, expected, atol=1e-4))