Unverified Commit df0e2a4f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

support latest few-step wan LoRA. (#12541)

* support latest few-step wan LoRA.

* up

* up
parent 303efd2b
......@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
if any("head.head" in k for k in state_dict):
if any("head.head" in k for k in original_state_dict):
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
f"head.head.{lora_up_key}.weight"
)
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
# an identity.
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
if f"head.head.{lora_down_key}.weight" in state_dict:
logger.info(
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
)
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
).T
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment