Unverified Commit 5d4f723b authored by Linoy Tsaban, committed by GitHub

[LoRA] kijai wan lora support for I2V (#11588)



* testing

* testing

* testing

* testing

* testing

* i2v

* i2v

* device fix

* testing

* fix

* fix

* fix

* fix

* fix

* Apply style fixes

* empty commit

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 05c8b42b
@@ -4813,20 +4813,41 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         if transformer.config.image_dim is None:
             return state_dict
 
+        target_device = transformer.device
+
         if any(k.startswith("transformer.blocks.") for k in state_dict):
-            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
 
             if is_i2v_lora:
                 return state_dict
 
             for i in range(num_blocks):
                 for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if the block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                     )
+
+                    # If the original LoRA had biases (indicated by has_bias)
+                    # AND the specific reference bias key exists for this block.
+
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+                            ref_lora_B_bias_tensor,
+                            device=target_device,
+                        )
 
         return state_dict
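
For context, a minimal standalone sketch of what this expansion does, using a toy single-block kijai-style T2V state dict. The tensor shapes and the toy dict below are illustrative, not taken from the PR:

    import torch

    # Toy T2V LoRA state dict for a single block; rank and dims are illustrative.
    state_dict = {
        "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(4, 64),
        "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(64, 4),
    }

    # Mirror of the expansion in the diff: add zero-filled add_k_proj/add_v_proj
    # LoRA weights so the T2V checkpoint also covers the image cross-attention
    # projections of the I2V transformer. Zeros keep the LoRA a no-op on those
    # projections while making the key set complete for loading.
    for c in ("add_k_proj", "add_v_proj"):
        ref_A = state_dict["transformer.blocks.0.attn2.to_k.lora_A.weight"]
        ref_B = state_dict["transformer.blocks.0.attn2.to_k.lora_B.weight"]
        state_dict[f"transformer.blocks.0.attn2.{c}.lora_A.weight"] = torch.zeros_like(ref_A)
        state_dict[f"transformer.blocks.0.attn2.{c}.lora_B.weight"] = torch.zeros_like(ref_B)

    print(sorted(state_dict.keys()))  # now includes add_k_proj/add_v_proj keys

In the pipeline itself this expansion runs during LoRA loading, so a T2V LoRA should load into a Wan I2V pipeline via load_lora_weights without manual key surgery.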