Unverified Commit 124ac3e8 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] feat: support non-diffusers wan t2v loras. (#11059)

feat: support non-diffusers wan t2v loras.
parent 2f0f281b
...@@ -1355,6 +1355,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): ...@@ -1355,6 +1355,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
for i in range(num_blocks): for i in range(num_blocks):
# Self-attention # Self-attention
...@@ -1374,13 +1375,15 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): ...@@ -1374,13 +1375,15 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_B.weight" f"blocks.{i}.cross_attn.{o}.lora_B.weight"
) )
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( if is_i2v_lora:
f"blocks.{i}.cross_attn.{o}.lora_A.weight" for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
) converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( f"blocks.{i}.cross_attn.{o}.lora_A.weight"
f"blocks.{i}.cross_attn.{o}.lora_B.weight" )
) converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
)
# FFN # FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment