Unverified Commit 816ca004 authored by Sayak Paul, committed by GitHub

[LoRA] Fix SDXL text encoder LoRAs (#4371)



* temporarily disable text encoder loras.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging.

* modify doc.

* rename tests.

* print slices.

* fix: assertions

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fef8d2f7
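For orientation, this is the user-facing behavior the fix targets: loading a Kohya-style SDXL LoRA that carries weights for both text encoders. Below is a minimal usage sketch, assuming a diffusers build that includes this change; the LoRA repository name and weight file are hypothetical placeholders.

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# `load_lora_weights` splits the checkpoint into UNet, `text_encoder`, and
# `text_encoder_2` parts; the stricter prefix matching introduced in this PR
# keeps the two text encoders' weights from being mixed up.
pipe.load_lora_weights(
    "some-user/some-sdxl-lora",                      # hypothetical LoRA repo
    weight_name="pytorch_lora_weights.safetensors",  # hypothetical file name
)

image = pipe("masterpiece, best quality, mountain", num_inference_steps=30).images[0]
```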
@@ -401,5 +401,4 @@ Thanks to [@isidentical](https://github.com/isidentical) for helping us on integ
### Known limitations specific to the Kohya-styled LoRAs
* SDXL LoRAs that have both the text encoders are currently leading to weird results. We're actively investigating the issue.
* When images don't look similar to those from other UIs such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
\ No newline at end of file
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import re
import warnings
@@ -258,6 +259,7 @@ class UNet2DConditionLoadersMixin:
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
is_network_alphas_none = network_alphas is None
if use_safetensors and not is_safetensors_available():
raise ValueError(
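For readers unfamiliar with the kohya-ss convention referenced in the comment above: the stored `alpha` rescales the low-rank update, so the delta applied to a frozen weight is `(alpha / rank) * up @ down`. A standalone sketch of just that scaling, independent of the diffusers classes:

```python
import torch

rank, alpha = 4, 1.0                           # typical kohya-ss values
in_features, out_features = 320, 320

down = torch.randn(rank, in_features) * 0.01   # "lora_down" weight
up = torch.zeros(out_features, rank)           # "lora_up" is initialized to zero

# Effective update added to the frozen base weight under the kohya convention;
# a smaller alpha relative to rank means a weaker LoRA contribution.
delta_w = (alpha / rank) * (up @ down)
assert delta_w.shape == (out_features, in_features)
```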
@@ -349,13 +351,20 @@ class UNet2DConditionLoadersMixin:
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
for k in network_alphas:
network_alphas_ = copy.deepcopy(network_alphas)
for k in network_alphas_:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
if not is_network_alphas_none:
if len(network_alphas) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
if len(state_dict) > 0:
raise ValueError(
f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
)
for key, value_dict in lora_grouped_dict.items():
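The hunk above switches the alpha lookup from indexing to `pop` and iterates over a deep copy so the dictionary is not mutated while being traversed; the new `is_network_alphas_none` guard then lets the loader assert that every alpha (and every state-dict entry) was actually consumed. A small standalone illustration of that pop-and-verify bookkeeping pattern, with purely illustrative names:

```python
import copy

alphas = {"layer1.alpha": 1.0, "layer2.alpha": 2.0}
targets = ["layer1.lora.down.weight", "layer2.lora.down.weight"]

mapped = {}
for target in targets:
    # Iterate over a copy so popping from `alphas` does not change the dict
    # we are currently looping over.
    for k in copy.deepcopy(alphas):
        if k.replace(".alpha", "") in target:
            mapped[target] = alphas.pop(k)

# Any leftover key means an alpha was never matched to a module; surfacing
# that early is exactly what the new ValueError above does.
if len(alphas) > 0:
    raise ValueError(f"Unconsumed alphas: {', '.join(alphas.keys())}")
```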
@@ -434,14 +443,6 @@ class UNet2DConditionLoadersMixin:
v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
out_rank=rank_mapping.get("to_out_lora.down.weight"),
out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
# rank=rank_mapping.get("to_k_lora.down.weight", None),
# hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
# q_rank=rank_mapping.get("to_q_lora.down.weight", None),
# q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
# v_rank=rank_mapping.get("to_v_lora.down.weight", None),
# v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
# out_rank=rank_mapping.get("to_out_lora.down.weight", None),
# out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
)
else:
attn_processors[key] = attn_processor_class(
@@ -496,9 +497,6 @@ class UNet2DConditionLoadersMixin:
# set ff layers
for target_module, lora_layer in non_attn_lora_layers:
target_module.set_lora_layer(lora_layer)
# It should raise an error if we don't have a set lora here
# if hasattr(target_module, "set_lora_layer"):
# target_module.set_lora_layer(lora_layer)
def save_attn_procs(
self,
@@ -1251,9 +1249,10 @@ class LoraLoaderMixin:
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix)]
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
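The extra `k.split(".")[0] == prefix` condition is the heart of the SDXL fix: keys belonging to the second text encoder start with `text_encoder_2`, which also passes a plain `startswith("text_encoder")` test, so the first encoder could previously pick up the second encoder's layers. A quick illustration with made-up key names:

```python
keys = [
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight",
    "text_encoder_2.text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight",
]
prefix = "text_encoder"

loose = [k for k in keys if k.startswith(prefix)]
strict = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]

print(len(loose))   # 2: the second encoder's key leaks in
print(len(strict))  # 1: only keys whose first dotted segment is exactly `prefix`
```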
@@ -1303,6 +1302,14 @@ class LoraLoaderMixin:
].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
cls._modify_text_encoder(
text_encoder,
lora_scale,
@@ -1364,12 +1371,13 @@ class LoraLoaderMixin:
lora_parameters = []
network_alphas = {} if network_alphas is None else network_alphas
is_network_alphas_populated = len(network_alphas) > 0
for name, attn_module in text_encoder_attn_modules(text_encoder):
query_alpha = network_alphas.get(name + ".k.proj.alpha")
key_alpha = network_alphas.get(name + ".q.proj.alpha")
value_alpha = network_alphas.get(name + ".v.proj.alpha")
proj_alpha = network_alphas.get(name + ".out.proj.alpha")
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
attn_module.q_proj = PatchedLoraProjection(
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
@@ -1387,14 +1395,14 @@ class LoraLoaderMixin:
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
attn_module.out_proj = PatchedLoraProjection(
attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
if patch_mlp:
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
fc1_alpha = network_alphas.get(name + ".fc1.alpha")
fc2_alpha = network_alphas.get(name + ".fc2.alpha")
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
mlp_module.fc1 = PatchedLoraProjection(
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
@@ -1406,6 +1414,11 @@ class LoraLoaderMixin:
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
if is_network_alphas_populated and len(network_alphas) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
return lora_parameters
@classmethod
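`PatchedLoraProjection`, used throughout the hunks above, wraps each frozen q/k/v/out (and MLP) linear layer so a scaled low-rank update is added to its output, with the popped `network_alpha` rescaling that update. The following is a conceptual sketch under those assumptions, not the actual diffusers class:

```python
import torch
import torch.nn as nn


class TinyLoraProjection(nn.Module):
    """Conceptual stand-in for a LoRA-patched linear projection."""

    def __init__(self, base, rank=4, lora_scale=1.0, network_alpha=None):
        super().__init__()
        self.base = base                                   # frozen original projection
        self.lora_scale = lora_scale
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)                     # the patch starts as a no-op
        # kohya-style rescaling of the low-rank branch
        self.alpha_scale = network_alpha / rank if network_alpha is not None else 1.0

    def forward(self, x):
        return self.base(x) + self.lora_scale * self.alpha_scale * self.up(self.down(x))


proj = TinyLoraProjection(nn.Linear(768, 768), rank=4, network_alpha=1.0)
out = proj(torch.randn(2, 77, 768))                        # same shape as the unpatched output
```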
@@ -1519,10 +1532,6 @@ class LoraLoaderMixin:
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
# if lora_name_alpha in state_dict:
# alpha = state_dict.pop(lora_name_alpha).item()
# network_alphas.update({lora_name_alpha: alpha})
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
......
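The surviving branch above performs the first step of mapping kohya-ss UNet key names to diffusers-style dotted names; the converter is assumed to restore compound identifiers such as `down_blocks` and `to_q` in later replacements. The snippet below only demonstrates the string transformation shown, on a made-up key:

```python
key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
lora_name = key.split(".")[0]          # everything before ".lora_down.weight"

if lora_name.startswith("lora_unet_"):
    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")

print(diffusers_name)
# -> "down.blocks.0.attentions.0.transformer.blocks.0.attn1.to.q.lora.down.weight"
```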
@@ -737,7 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
self.assertTrue(np.allclose(images, expected, atol=1e-4))
@@ -760,7 +760,7 @@ class LoraIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-4))
def test_unload_lora(self):
def test_unload_kohya_lora(self):
generator = torch.manual_seed(0)
prompt = "masterpiece, best quality, mountain"
num_inference_steps = 2
......