Unverified Commit 816ca004 authored by Sayak Paul, committed by GitHub

[LoRA] Fix SDXL text encoder LoRAs (#4371)



* temporarily disable text encoder loras.

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging.

* modify doc.

* rename tests.

* print slices.

* fix: assertions

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fef8d2f7
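The squashed commit titles above say little about the change itself, so for orientation: the fix targets Kohya-style SDXL LoRAs that patch the UNet and both text encoders (`text_encoder` and `text_encoder_2`). A minimal usage sketch of the flow this commit touches, with a placeholder LoRA path (not something shipped with this commit):

```python
# Sketch only: loading a Kohya-style SDXL LoRA that carries UNet and dual
# text-encoder weights. The checkpoint path below is a placeholder.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Kohya SDXL checkpoints typically use `lora_unet_*`, `lora_te1_*`, and
# `lora_te2_*` key prefixes; load_lora_weights() converts and routes them
# (together with their per-layer alphas) to the UNet and the two text encoders.
pipe.load_lora_weights("path/to/kohya_sdxl_lora.safetensors")

image = pipe("masterpiece, best quality, mountain", num_inference_steps=30).images[0]
```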
@@ -401,5 +401,4 @@ Thanks to [@isidentical](https://github.com/isidentical) for helping us on integ

 ### Known limitations specific to the Kohya-styled LoRAs

-* SDXL LoRAs that have both the text encoders are currently leading to weird results. We're actively investigating the issue.
 * When images don't looks similar to other UIs such ComfyUI, it can be beacause of multiple reasons as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
\ No newline at end of file
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import re
 import warnings
@@ -258,6 +259,7 @@ class UNet2DConditionLoadersMixin:
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+        is_network_alphas_none = network_alphas is None

         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -349,13 +351,20 @@ class UNet2DConditionLoadersMixin:
                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                 if network_alphas is not None:
-                    for k in network_alphas:
+                    network_alphas_ = copy.deepcopy(network_alphas)
+                    for k in network_alphas_:
                         if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+
+        if not is_network_alphas_none:
+            if len(network_alphas) > 0:
+                raise ValueError(
+                    f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                )

         if len(state_dict) > 0:
             raise ValueError(
-                f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
             )

         for key, value_dict in lora_grouped_dict.items():
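The hunk above replaces the `.get()`-style lookup with a consume-then-verify pattern: each alpha is popped as soon as it is mapped to an attention processor, and any key that was never claimed triggers a `ValueError`. A toy, self-contained sketch of that pattern (function and key names here are illustrative, not the actual diffusers internals):

```python
# Illustrative consume-then-verify mapping: pop alphas as they are matched,
# then fail loudly if any alpha key was never claimed by a module.
import copy

def map_alphas(network_alphas, module_keys):
    mapped = {}
    remaining = copy.deepcopy(network_alphas)
    for module_key in module_keys:
        for alpha_key in list(remaining):
            if alpha_key.replace(".alpha", "") in module_key:
                mapped[module_key] = remaining.pop(alpha_key)
    if len(remaining) > 0:
        raise ValueError(f"Unclaimed alpha keys: {', '.join(remaining.keys())}")
    return mapped

alphas = {"down_blocks.0.attn1.to_q.alpha": 8.0}
print(map_alphas(alphas, ["down_blocks.0.attn1.to_q.lora.down.weight"]))
# -> {'down_blocks.0.attn1.to_q.lora.down.weight': 8.0}
```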
@@ -434,14 +443,6 @@ class UNet2DConditionLoadersMixin:
                     v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
                     out_rank=rank_mapping.get("to_out_lora.down.weight"),
                     out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
-                    # rank=rank_mapping.get("to_k_lora.down.weight", None),
-                    # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
-                    # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
-                    # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
-                    # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
-                    # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
-                    # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
-                    # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
                 )
             else:
                 attn_processors[key] = attn_processor_class(
@@ -496,9 +497,6 @@ class UNet2DConditionLoadersMixin:
         # set ff layers
         for target_module, lora_layer in non_attn_lora_layers:
             target_module.set_lora_layer(lora_layer)
-            # It should raise an error if we don't have a set lora here
-            # if hasattr(target_module, "set_lora_layer"):
-            #     target_module.set_lora_layer(lora_layer)

     def save_attn_procs(
         self,
@@ -1251,9 +1249,10 @@ class LoraLoaderMixin:
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix

+        # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
             text_encoder_lora_state_dict = {
                 k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
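The extra `k.split(".")[0] == prefix` condition is the part that keeps SDXL's two text encoders apart: `startswith("text_encoder")` alone also matches `text_encoder_2.*` keys, so the first encoder would silently absorb the second one's LoRA layers. A small illustration with made-up keys:

```python
# Made-up state-dict keys for the two SDXL text encoders.
keys = [
    "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight",
    "text_encoder_2.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight",
]
prefix = "text_encoder"

naive = [k for k in keys if k.startswith(prefix)]
strict = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]

print(len(naive))   # 2 -- the text_encoder_2 key leaks into the first encoder
print(len(strict))  # 1 -- only keys whose first path segment is exactly "text_encoder"
```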
@@ -1303,6 +1302,14 @@ class LoraLoaderMixin:
             ].shape[1]
             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

+            if network_alphas is not None:
+                alpha_keys = [
+                    k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                ]
+                network_alphas = {
+                    k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+
             cls._modify_text_encoder(
                 text_encoder,
                 lora_scale,
@@ -1364,12 +1371,13 @@ class LoraLoaderMixin:
         lora_parameters = []
         network_alphas = {} if network_alphas is None else network_alphas
+        is_network_alphas_populated = len(network_alphas) > 0

         for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.get(name + ".k.proj.alpha")
-            key_alpha = network_alphas.get(name + ".q.proj.alpha")
-            value_alpha = network_alphas.get(name + ".v.proj.alpha")
-            proj_alpha = network_alphas.get(name + ".out.proj.alpha")
+            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
+            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
+            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
+            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)

             attn_module.q_proj = PatchedLoraProjection(
                 attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
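For context on the alpha values being popped in this hunk: a kohya-style `network_alpha` rescales the low-rank update by `alpha / rank` before it is added to the frozen projection's output (the same meaning as `--network_alpha` in the kohya-ss trainer, as noted earlier in this file). A rough standalone sketch of that scaling, assuming plain tensors rather than the `PatchedLoraProjection` class itself:

```python
# Rough sketch of LoRA with a kohya-style network_alpha: the low-rank update
# is scaled by alpha / rank. Shapes and values are arbitrary.
import torch

rank, alpha, lora_scale = 4, 8.0, 1.0
hidden_dim, out_dim = 64, 64

x = torch.randn(1, hidden_dim)
weight = torch.randn(out_dim, hidden_dim)         # frozen base projection
lora_down = torch.randn(rank, hidden_dim) * 0.01  # trained "down" matrix
lora_up = torch.zeros(out_dim, rank)              # "up" matrix, zero-initialized

base_out = x @ weight.T
lora_out = (x @ lora_down.T) @ lora_up.T * (alpha / rank) * lora_scale
print((base_out + lora_out).shape)  # torch.Size([1, 64])
```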
@@ -1387,14 +1395,14 @@ class LoraLoaderMixin:
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.get(name + ".fc1.alpha")
-                fc2_alpha = network_alphas.get(name + ".fc2.alpha")
+                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
+                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")

                 mlp_module.fc1 = PatchedLoraProjection(
                     mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
@@ -1406,6 +1414,11 @@ class LoraLoaderMixin:
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

+        if is_network_alphas_populated and len(network_alphas) > 0:
+            raise ValueError(
+                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+            )
+
         return lora_parameters

     @classmethod
@@ -1519,10 +1532,6 @@ class LoraLoaderMixin:
             lora_name_up = lora_name + ".lora_up.weight"
             lora_name_alpha = lora_name + ".alpha"

-            # if lora_name_alpha in state_dict:
-            #     alpha = state_dict.pop(lora_name_alpha).item()
-            #     network_alphas.update({lora_name_alpha: alpha})

             if lora_name.startswith("lora_unet_"):
                 diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
...
@@ -737,7 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
         ).images

         images = images[0, -3:, -3:, -1].flatten()
-        expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])

         self.assertTrue(np.allclose(images, expected, atol=1e-4))
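The new `expected` array is a regression slice: the test flattens the bottom-right 3×3 patch of the last channel of the first generated image and compares it elementwise against the reference values with `atol=1e-4`. A self-contained sketch of that check on dummy data:

```python
# Dummy stand-in for the pipeline output: (batch, height, width, channels).
import numpy as np

images = np.random.rand(1, 64, 64, 3).astype(np.float32)
image_slice = images[0, -3:, -3:, -1].flatten()  # 9 values from the corner patch
expected = image_slice.copy()                    # reference values in the real test
assert np.allclose(image_slice, expected, atol=1e-4)
```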
@@ -760,7 +760,7 @@ class LoraIntegrationTests(unittest.TestCase):
         self.assertTrue(np.allclose(images, expected, atol=1e-4))

-    def test_unload_lora(self):
+    def test_unload_kohya_lora(self):
         generator = torch.manual_seed(0)
         prompt = "masterpiece, best quality, mountain"
         num_inference_steps = 2
...