Unverified Commit a34d97ce authored by Linoy Tsaban, committed by GitHub

[Wan LoRAs] make T2V LoRAs compatible with Wan I2V (#11107)



* @hlky t2v->i2v

* Apply style fixes

* try with ones to not nullify layers

* fix method name

* revert to zeros

* add check to state_dict keys

* add comment

* copies fix

* Revert "copies fix"

This reverts commit 051f534d185c0ea065bf36a9926c4b48f496d429.

* remove copied from

* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: hlky <hlky@hlky.ac>

* update

* update

* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: hlky <hlky@hlky.ac>

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Linoy <linoy@hf.co>
Co-authored-by: hlky <hlky@hlky.ac>
parent fc28791f
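
What this change enables, as a minimal usage sketch (the model ID and LoRA path below are illustrative placeholders, not taken from this PR): a LoRA trained on the Wan T2V transformer can now be loaded into the Wan I2V pipeline, whose transformer carries extra image cross-attention projections (add_k_proj/add_v_proj) that a T2V checkpoint does not cover.

    import torch
    from diffusers import WanImageToVideoPipeline

    # Placeholder I2V checkpoint; any Wan I2V diffusers-format repo works the same way.
    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
    )
    # A T2V-trained LoRA: the loader now pads the missing _img-layer entries
    # with zeros so the state dict matches the I2V transformer.
    pipe.load_lora_weights("path/to/wan_t2v_lora.safetensors")
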
@@ -4249,7 +4249,33 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    @classmethod
    def _maybe_expand_t2v_lora_for_i2v(
        cls,
        transformer: torch.nn.Module,
        state_dict,
    ):
        if transformer.config.image_dim is None:
            return state_dict

        if any(k.startswith("transformer.blocks.") for k in state_dict):
            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
            if is_i2v_lora:
                return state_dict

            for i in range(num_blocks):
                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
                    )
                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
                    )

        return state_dict

    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
@@ -4287,7 +4313,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
        state_dict = self._maybe_expand_t2v_lora_for_i2v(
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            state_dict=state_dict,
        )

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")
......
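
A note on why zero-padding is safe (the commit history above briefly tried ones before reverting to zeros): a LoRA update is lora_B @ lora_A, so all-zero A and B matrices contribute exactly nothing to the new add_k_proj/add_v_proj layers, and the image cross-attention runs on its base weights alone. A standalone sketch of that invariant, with illustrative shapes (rank 4, hidden size 16):

    import torch

    rank, dim = 4, 16
    # Zero-padded entries, mirroring what _maybe_expand_t2v_lora_for_i2v inserts
    lora_A = torch.zeros(rank, dim)
    lora_B = torch.zeros(dim, rank)

    delta = lora_B @ lora_A       # the weight update a LoRA contributes
    assert torch.all(delta == 0)  # the padded _img layers are an exact no-op

    x = torch.randn(2, dim)
    base = torch.nn.Linear(dim, dim, bias=False)
    # With the padded LoRA applied, the layer's output equals the base output.
    assert torch.allclose(base(x) + x @ lora_A.T @ lora_B.T, base(x))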