Unverified commit f9fd5114 authored by Sayak Paul, committed by GitHub

[LoRA] support Kohya Flux LoRAs that have text encoders as well (#9542)

* support Kohya Flux LoRAs that have text encoders.
parent 8e7d6c03
@@ -516,10 +516,47 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
                 f"transformer.single_transformer_blocks.{i}.norm.linear",
             )
 
+        remaining_keys = list(sds_sd.keys())
+        te_state_dict = {}
+        if remaining_keys:
+            if not all(k.startswith("lora_te1") for k in remaining_keys):
+                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
+            for key in remaining_keys:
+                if not key.endswith("lora_down.weight"):
+                    continue
+
+                lora_name = key.split(".")[0]
+                lora_name_up = f"{lora_name}.lora_up.weight"
+                lora_name_alpha = f"{lora_name}.alpha"
+                diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    down_weight = sds_sd.pop(key)
+                    sd_lora_rank = down_weight.shape[0]
+                    te_state_dict[diffusers_name] = down_weight
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
+
+                if lora_name_alpha in sds_sd:
+                    alpha = sds_sd.pop(lora_name_alpha).item()
+                    scale = alpha / sd_lora_rank
+
+                    scale_down = scale
+                    scale_up = 1.0
+                    while scale_down * 2 < scale_up:
+                        scale_down *= 2
+                        scale_up /= 2
+
+                    te_state_dict[diffusers_name] *= scale_down
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
+
         if len(sds_sd) > 0:
-            logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
-        return ait_sd
+            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
+
+        if te_state_dict:
+            te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
+
+        new_state_dict = {**ait_sd, **te_state_dict}
+        return new_state_dict
 
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
...
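The alpha handling in this hunk folds the effective LoRA scale (alpha / rank) into the converted weights once, splitting it between the down and up matrices so that neither tensor is multiplied by an extreme factor. A minimal standalone sketch of that rebalancing, with an assumed helper name (not the library function itself):

# Sketch of the alpha rebalancing used in the hunk above; split_lora_scale
# is an illustrative name, not a diffusers API. The invariant is that
# scale_down * scale_up always equals alpha / rank.
def split_lora_scale(alpha: float, rank: int) -> tuple[float, float]:
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    # Move powers of two from scale_up to scale_down until the two
    # factors are within a factor of two of each other.
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

print(split_lora_scale(alpha=8, rank=128))  # (0.25, 0.25); product is 0.0625 == 8 / 128
print(split_lora_scale(alpha=16, rank=16))  # (1.0, 1.0); alpha == rank leaves weights untouched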
@@ -228,6 +228,26 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
 
+    def test_flux_kohya_with_text_encoder(self):
+        self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "optimus is cleaning the house with broomstick"
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=4.5,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
     def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
         self.pipeline.fuse_lora()
...
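For reference, an end-user flow exercising the same path as the new test might look like the following sketch. The LoRA repo id and weight name come from the test above; the base checkpoint and dtype are assumptions:

import torch
from diffusers import FluxPipeline

# Assumed base model; any Flux checkpoint the LoRA was trained against works.
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Kohya-style Flux LoRAs whose state dicts also carry "lora_te1_*" keys now
# load through the usual entry point; the converter above routes those keys
# to the CLIP text encoder under a "text_encoder." prefix.
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()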