Unverified Commit c2a28c34 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

Refactor LoRA (#3778)



* refactor to support patching LoRA into T5

instantiate the lora linear layer on the same device as the regular linear layer

get lora rank from state dict

tests

fmt

can create lora layer in float32 even when rest of model is float16

fix loading model hook

remove load_lora_weights_ and T5 dispatching

remove Unet#attn_processors_state_dict

docstrings

* text encoder monkeypatch class method

* fix test

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 78922ed7
...@@ -23,6 +23,7 @@ import os ...@@ -23,6 +23,7 @@ import os
import shutil import shutil
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict
import numpy as np import numpy as np
import torch import torch
...@@ -50,7 +51,10 @@ from diffusers import ( ...@@ -50,7 +51,10 @@ from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.loaders import (
LoraLoaderMixin,
text_encoder_lora_state_dict,
)
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnAddedKVProcessor, AttnAddedKVProcessor,
AttnAddedKVProcessor2_0, AttnAddedKVProcessor2_0,
...@@ -60,7 +64,7 @@ from diffusers.models.attention_processor import ( ...@@ -60,7 +64,7 @@ from diffusers.models.attention_processor import (
SlicedAttnAddedKVProcessor, SlicedAttnAddedKVProcessor,
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
...@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte ...@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return prompt_embeds return prompt_embeds
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
r"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors
attn_processors_state_dict = {}
for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
return attn_processors_state_dict
def main(args): def main(args):
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
...@@ -833,6 +853,7 @@ def main(args): ...@@ -833,6 +853,7 @@ def main(args):
# Set correct lora layers # Set correct lora layers
unet_lora_attn_procs = {} unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items(): for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"): if name.startswith("mid_block"):
...@@ -850,35 +871,18 @@ def main(args): ...@@ -850,35 +871,18 @@ def main(args):
lora_attn_processor_class = ( lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
) )
unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
cross_attention_dim=cross_attention_dim, unet_lora_attn_procs[name] = module
rank=args.rank, unet_lora_parameters.extend(module.parameters())
)
unet.set_attn_processor(unet_lora_attn_procs) unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it. # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this, # So, instead, we monkey-patch the forward calls of its attention-blocks.
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
text_encoder_lora_layers = None
if args.train_text_encoder: if args.train_text_encoder:
text_lora_attn_procs = {} # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
for name, module in text_encoder.named_modules(): text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_proj.out_features,
cross_attention_dim=None,
rank=args.rank,
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
text_encoder = temp_pipeline.text_encoder
del temp_pipeline
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
...@@ -887,23 +891,13 @@ def main(args): ...@@ -887,23 +891,13 @@ def main(args):
unet_lora_layers_to_save = None unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None text_encoder_lora_layers_to_save = None
if args.train_text_encoder:
text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
for model in models: for model in models:
state_dict = model.state_dict() if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
if ( elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers is not None text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
and text_encoder_keys is not None else:
and state_dict.keys() == text_encoder_keys raise ValueError(f"unexpected save model: {model.__class__}")
):
# text encoder
text_encoder_lora_layers_to_save = state_dict
elif state_dict.keys() == unet_keys:
# unet
unet_lora_layers_to_save = state_dict
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() weights.pop()
...@@ -915,27 +909,24 @@ def main(args): ...@@ -915,27 +909,24 @@ def main(args):
) )
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
# Note we DON'T pass the unet and text encoder here an purpose unet_ = None
# so that the we don't accidentally override the LoRA layers of text_encoder_ = None
# unet_lora_layers and text_encoder_lora_layers which are stored in `models`
# with new torch.nn.Modules / weights. We simply use the pipeline class as
# an easy way to load the lora checkpoints
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
torch_dtype=weight_dtype,
)
temp_pipeline.load_lora_weights(input_dir)
# load lora weights into models while len(models) > 0:
models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) model = models.pop()
if len(models) > 1:
models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
# delete temporary pipeline and pop models if isinstance(model, type(accelerator.unwrap_model(unet))):
del temp_pipeline unet_ = model
for _ in range(len(models)): elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
models.pop() text_encoder_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
)
accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)
...@@ -965,9 +956,9 @@ def main(args): ...@@ -965,9 +956,9 @@ def main(args):
# Optimizer creation # Optimizer creation
params_to_optimize = ( params_to_optimize = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder if args.train_text_encoder
else unet_lora_layers.parameters() else unet_lora_parameters
) )
optimizer = optimizer_class( optimizer = optimizer_class(
params_to_optimize, params_to_optimize,
...@@ -1056,12 +1047,12 @@ def main(args): ...@@ -1056,12 +1047,12 @@ def main(args):
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.
if args.train_text_encoder: if args.train_text_encoder:
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler unet, text_encoder, optimizer, train_dataloader, lr_scheduler
) )
else: else:
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, optimizer, train_dataloader, lr_scheduler unet, optimizer, train_dataloader, lr_scheduler
) )
# We need to recalculate our total training steps as the size of the training dataloader may have changed. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
...@@ -1210,9 +1201,9 @@ def main(args): ...@@ -1210,9 +1201,9 @@ def main(args):
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients: if accelerator.sync_gradients:
params_to_clip = ( params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder if args.train_text_encoder
else unet_lora_layers.parameters() else unet_lora_parameters
) )
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step() optimizer.step()
...@@ -1301,15 +1292,17 @@ def main(args): ...@@ -1301,15 +1292,17 @@ def main(args):
pipeline_args = {"prompt": args.validation_prompt} pipeline_args = {"prompt": args.validation_prompt}
if args.validation_images is None: if args.validation_images is None:
images = [ images = []
pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images):
for _ in range(args.num_validation_images) with torch.cuda.amp.autocast():
] image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else: else:
images = [] images = []
for image in args.validation_images: for image in args.validation_images:
image = Image.open(image) image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0] with torch.cuda.amp.autocast():
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image) images.append(image)
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
...@@ -1332,12 +1325,16 @@ def main(args): ...@@ -1332,12 +1325,16 @@ def main(args):
# Save the lora layers # Save the lora layers
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32) unet = unet.to(torch.float32)
unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) unet_lora_layers = unet_attn_processors_state_dict(unet)
if text_encoder is not None: if text_encoder is not None and args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder = text_encoder.to(torch.float32) text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
else:
text_encoder_lora_layers = None
LoraLoaderMixin.save_lora_weights( LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir, save_directory=args.output_dir,
......
This diff is collapsed.
...@@ -506,14 +506,14 @@ class AttnProcessor: ...@@ -506,14 +506,14 @@ class AttnProcessor:
class LoRALinearLayer(nn.Module): class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None): def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
super().__init__() super().__init__()
if rank > min(in_features, out_features): if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False) self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False) self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha self.network_alpha = network_alpha
......
...@@ -30,7 +30,6 @@ from .constants import ( ...@@ -30,7 +30,6 @@ from .constants import (
ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
TEXT_ENCODER_ATTN_MODULE,
WEIGHTS_NAME, WEIGHTS_NAME,
) )
from .deprecation_utils import deprecate from .deprecation_utils import deprecate
......
...@@ -30,4 +30,3 @@ DIFFUSERS_CACHE = default_cache_path ...@@ -30,4 +30,3 @@ DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
...@@ -12,18 +12,19 @@ ...@@ -12,18 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import os import os
import tempfile import tempfile
import unittest import unittest
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
Attention, Attention,
AttnProcessor, AttnProcessor,
...@@ -33,7 +34,8 @@ from diffusers.models.attention_processor import ( ...@@ -33,7 +34,8 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device from diffusers.utils import floats_tensor, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, slow
def create_unet_lora_layers(unet: nn.Module): def create_unet_lora_layers(unet: nn.Module):
...@@ -63,11 +65,15 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): ...@@ -63,11 +65,15 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
lora_attn_processor_class = ( lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
) )
for name, module in text_encoder.named_modules(): for name, module in text_encoder_attn_modules(text_encoder):
if name.endswith(TEXT_ENCODER_ATTN_MODULE): if isinstance(module.out_proj, nn.Linear):
text_lora_attn_procs[name] = lora_attn_processor_class( out_features = module.out_proj.out_features
hidden_size=module.out_proj.out_features, cross_attention_dim=None elif isinstance(module.out_proj, PatchedLoraProjection):
) out_features = module.out_proj.regular_linear_layer.out_features
else:
assert False, module.out_proj.__class__
text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None)
return text_lora_attn_procs return text_lora_attn_procs
...@@ -77,17 +83,13 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module): ...@@ -77,17 +83,13 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
return text_encoder_lora_layers return text_encoder_lora_layers
def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
for _, attn_proc in text_lora_attn_procs.items(): with torch.no_grad():
# set up.weights for parameter in text_lora_attn_parameters:
for layer_name, layer_module in attn_proc.named_modules(): if randn_weight:
if layer_name.endswith("_lora"): parameter[:] = torch.randn_like(parameter)
weight = ( else:
torch.randn_like(layer_module.up.weight) torch.zero_(parameter)
if randn_weight
else torch.zeros_like(layer_module.up.weight)
)
layer_module.up.weight = torch.nn.Parameter(weight)
class LoraLoaderMixinTests(unittest.TestCase): class LoraLoaderMixinTests(unittest.TestCase):
...@@ -281,16 +283,10 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -281,16 +283,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32) assert outputs_without_lora.shape == (1, 77, 32)
# create lora_attn_procs with zeroed out up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=False)
# monkey patch # monkey patch
pipe._modify_text_encoder(text_attn_procs) params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. set_lora_weights(params, randn_weight=False)
del text_attn_procs
gc.collect()
# inference with lora # inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
...@@ -301,15 +297,12 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -301,15 +297,12 @@ class LoraLoaderMixinTests(unittest.TestCase):
), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"
# create lora_attn_procs with randn up.weights # create lora_attn_procs with randn up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=True)
# monkey patch # monkey patch
pipe._modify_text_encoder(text_attn_procs) params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. set_lora_weights(params, randn_weight=True)
del text_attn_procs
gc.collect()
# inference with lora # inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
...@@ -329,16 +322,10 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -329,16 +322,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
assert outputs_without_lora.shape == (1, 77, 32) assert outputs_without_lora.shape == (1, 77, 32)
# create lora_attn_procs with randn up.weights
text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
set_lora_up_weights(text_attn_procs, randn_weight=True)
# monkey patch # monkey patch
pipe._modify_text_encoder(text_attn_procs) params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
# verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. set_lora_weights(params, randn_weight=True)
del text_attn_procs
gc.collect()
# inference with lora # inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
...@@ -467,3 +454,86 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -467,3 +454,86 @@ class LoraLoaderMixinTests(unittest.TestCase):
# Outputs shouldn't match. # Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
lora_model_id = "hf-internal-testing/lora_dreambooth_dog_example"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe(
"A photo of a sks dog floating in the river", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7207, 0.6787, 0.6010, 0.7478, 0.6838, 0.6064, 0.6984, 0.6443, 0.5785])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_dreambooth_text_encoder_new_format(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/lora-trained"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe("A photo of a sks dog", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.6628, 0.6138, 0.5390, 0.6625, 0.6130, 0.5463, 0.6166, 0.5788, 0.5359])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_a1111(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None).to(
torch_device
)
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_vanilla_funetuning(self):
generator = torch.Generator().manual_seed(0)
lora_model_id = "hf-internal-testing/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, safety_checker=None)
pipe = pipe.to(torch_device)
pipe.load_lora_weights(lora_model_id)
images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment