"...text-generation-inference.git" did not exist on "7fbfbb0dc5d9f928ee4c496a04bff4a0eca8a9c8"
Unverified Commit e4b056fe authored by Sayak Paul, committed by GitHub

[LoRA] support wan i2v loras from the world. (#11025)

* support wan i2v loras from the world.

* remove copied from.

* updates

* add lora.
parent 4e3ddd5a
@@ -14,6 +14,10 @@

# Wan

<div class="flex flex-wrap space-x-1">
  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.

<!-- TODO(aryan): update abstract once paper is out -->

...
@@ -1348,3 +1348,53 @@ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
    return converted_state_dict


def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
    converted_state_dict = {}
    original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}

    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})

    for i in range(num_blocks):
        # Self-attention
        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
                f"blocks.{i}.self_attn.{o}.lora_A.weight"
            )
            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
                f"blocks.{i}.self_attn.{o}.lora_B.weight"
            )

        # Cross-attention
        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
            )
            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
            )

        for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
            )
            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
            )

        # FFN
        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
                f"blocks.{i}.{o}.lora_A.weight"
            )
            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
                f"blocks.{i}.{o}.lora_B.weight"
            )

    if len(original_state_dict) > 0:
        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")

    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict
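
For reference, here is a minimal sketch of the key mapping this converter performs. The single-block, rank-4 state dict is illustrative only (real checkpoints carry these keys for every transformer block), and it assumes the helper is importable from diffusers.loaders.lora_conversion_utils once this change lands:

import torch

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers

# Toy single-block LoRA (rank 4) in the original "diffusion_model." layout.
original = {}
for proj in ["q", "k", "v", "o"]:
    original[f"diffusion_model.blocks.0.self_attn.{proj}.lora_A.weight"] = torch.zeros(4, 16)
    original[f"diffusion_model.blocks.0.self_attn.{proj}.lora_B.weight"] = torch.zeros(16, 4)
for proj in ["q", "k", "v", "o", "k_img", "v_img"]:
    original[f"diffusion_model.blocks.0.cross_attn.{proj}.lora_A.weight"] = torch.zeros(4, 16)
    original[f"diffusion_model.blocks.0.cross_attn.{proj}.lora_B.weight"] = torch.zeros(16, 4)
for layer in ["ffn.0", "ffn.2"]:
    original[f"diffusion_model.blocks.0.{layer}.lora_A.weight"] = torch.zeros(4, 16)
    original[f"diffusion_model.blocks.0.{layer}.lora_B.weight"] = torch.zeros(16, 4)

converted = _convert_non_diffusers_wan_lora_to_diffusers(original)

# Example mappings produced above:
#   "diffusion_model.blocks.0.self_attn.q.lora_A.weight"      -> "transformer.blocks.0.attn1.to_q.lora_A.weight"
#   "diffusion_model.blocks.0.cross_attn.k_img.lora_B.weight" -> "transformer.blocks.0.attn2.add_k_proj.lora_B.weight"
#   "diffusion_model.blocks.0.ffn.0.lora_A.weight"            -> "transformer.blocks.0.ffn.net.0.proj.lora_A.weight"
print(len(converted))  # 24 entries for the single toy block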
@@ -42,6 +42,7 @@ from .lora_conversion_utils import (
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_lumina2_lora_to_diffusers,
    _convert_non_diffusers_wan_lora_to_diffusers,
    _convert_xlabs_flux_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)
@@ -4111,7 +4112,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4198,6 +4198,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )
        if any(k.startswith("diffusion_model.") for k in state_dict):
            state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
...
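
With this hook in place, load_lora_weights on a Wan pipeline should accept such checkpoints directly: the loader spots the "diffusion_model." prefix and converts the state dict before loading. A rough usage sketch under those assumptions; the Wan I2V base checkpoint id is the published diffusers-format repo, while the LoRA repo id and filename are placeholders:

import torch

from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)

# Placeholder repo id / filename for a LoRA trained outside diffusers;
# state dicts whose keys start with "diffusion_model." are converted automatically.
pipe.load_lora_weights("some-user/wan-i2v-lora", weight_name="lora.safetensors")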