Unverified commit 7e7e62c6, authored by Dave Lage and committed by GitHub
Browse files

Convert alphas for embedders for sd-scripts to ai toolkit conversion (#12332)



* Convert alphas for embedders for sd-scripts to ai toolkit conversion

* Add kohya embedders conversion test

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent eda9ff83
......@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_sd[target_key] = value
if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_in_layer",
"time_text_embed.guidance_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_guidance_in_out_layer",
"time_text_embed.guidance_embedder.linear_2",
)
if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_img_in",
"x_embedder",
)
if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_txt_in",
"context_embedder",
)
if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_in_layer",
"time_text_embed.timestep_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_time_in_out_layer",
"time_text_embed.timestep_embedder.linear_2",
)
if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_in_layer",
"time_text_embed.text_embedder.linear_1",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
"lora_unet_vector_in_out_layer",
"time_text_embed.text_embedder.linear_2",
)
if any("final_layer" in k for k in sds_sd):
......
......@@ -907,6 +907,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3
def test_flux_kohya_embedders_conversion(self):
    """Smoke test: a kohya/sd-scripts LoRA containing embedder weights
    (guidance_in / img_in / txt_in / time_in / vector_in) must load and
    unload cleanly through the kohya->diffusers conversion path.

    The test passes if neither call raises; no numerical output is
    checked. (The previous trailing ``assert True`` was a no-op and has
    been removed.)
    """
    # Repository chosen because its state dict exercises the embedder
    # conversion branches in _convert_kohya_flux_lora_to_diffusers.
    self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
    # Unload to restore the shared pipeline for subsequent tests.
    self.pipeline.unload_lora_weights()
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment