Unverified Commit e312b230 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] support LyCORIS (#5102)

* better condition.

* debugging

* how about now?

* how about now?

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* support for lycoris.

* style

* add: lycoris test

* fix from_pretrained call.

* fix assertion values.
parent 8263cf00
......@@ -1878,7 +1878,7 @@ class LoraLoaderMixin:
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name:
if "emb" in diffusers_name and "time" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
......@@ -1890,6 +1890,13 @@ class LoraLoaderMixin:
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
......
......@@ -19,6 +19,7 @@ import torch
from torch import nn
from .activations import get_activation
from .lora import LoRACompatibleLinear
def get_timestep_embedding(
......@@ -166,7 +167,7 @@ class TimestepEmbedding(nn.Module):
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
......@@ -179,7 +180,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
......
......@@ -1876,6 +1876,25 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_lycoris(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
).to(torch_device)
lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
lora_filename = "edgLycorisMugler-light.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
def test_a1111_with_model_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment