Unverified Commit 8d1de408 authored by Linoy Tsaban, committed by GitHub

[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora (#12074)



* add alpha

* load into 2nd transformer

* Update src/diffusers/loaders/lora_conversion_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders/lora_conversion_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* pr comments

* pr comments

* pr comments

* fix

* fix

* Apply style fixes

* fix copies

* fix

* fix copies

* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* revert change

* revert change

* fix copies

* up

* fix

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: linoy <linoy@hf.co>
parent 8cc528c5
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
 - Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involved. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
+- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. Set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) example and [this](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) example to learn more.

 ## WanPipeline

 [[autodoc]] WanPipeline
...
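To make the new option from the bullet above concrete, here is a minimal usage sketch. The checkpoint id, the LightX2V LoRA filename, and the adapter names are illustrative assumptions, not values taken from this PR:

```python
import torch
from diffusers import WanPipeline

# Assumed Wan 2.2 checkpoint id; Wan 2.2 pipelines expose two denoisers
# (`transformer` and `transformer_2`).
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)

# Load the same LightX2V LoRA into both denoisers. The filename is a placeholder;
# pick an actual file from the Lightx2v folder of the LoRA repo.
lora_repo = "Kijai/WanVideo_comfy"
lora_file = "Lightx2v/<lightx2v-lora>.safetensors"  # placeholder filename

pipe.load_lora_weights(lora_repo, weight_name=lora_file, adapter_name="lightx2v")
pipe.load_lora_weights(
    lora_repo, weight_name=lora_file, adapter_name="lightx2v_2", load_into_transformer_2=True
)
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
```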
@@ -754,7 +754,11 @@ class LoraBaseMixin:
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}
         for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            # Guard for cases like Wan: Wan 2.1 and Wan VACE have a single denoiser,
+            # whereas Wan 2.2 has two.
+            if model is None:
+                continue
             for adapter_name, weights in zip(adapter_names, adapter_weights):
                 if isinstance(weights, dict):
...
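Because `_lora_loadable_modules` for Wan now lists both denoisers (see the `lora_pipeline.py` hunk further down), this guard lets `set_adapters` skip components that a given pipeline simply doesn't have. A short sketch, continuing the usage example above and assuming the per-component dict form of `adapter_weights` is keyed by the names in `_lora_loadable_modules`:

```python
# Scalar weights apply to every component the adapter was loaded into.
pipe.set_adapters(["lightx2v"], adapter_weights=[1.0])

# Assumed: dict weights are decomposed per component, as in the loop above.
# With the new guard, a component listed in _lora_loadable_modules but missing
# from the pipeline (e.g. transformer_2 on Wan 2.1) is skipped instead of
# raising in getattr.
pipe.set_adapters(["lightx2v"], adapter_weights=[{"transformer": 1.0, "transformer_2": 0.8}])
```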
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )

+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
             # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
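A standalone sketch (not part of the diff) of the invariant behind `get_alpha_scales`: the two factors always multiply to `alpha / rank`, so baking `scale_down` into `lora_A` and `scale_up` into `lora_B` reproduces the original `alpha / rank` scaling while keeping the two factors close in magnitude (e.g. `alpha=16`, `rank=64` gives `scale_down = scale_up = 0.5`):

```python
import torch

def split_alpha_scale(alpha: float, rank: int):
    # Mirrors the arithmetic of get_alpha_scales above, without the state-dict bookkeeping.
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

rank, alpha, out_dim, in_dim = 64, 16.0, 128, 256
down = torch.randn(rank, in_dim)  # lora_A / lora_down
up = torch.randn(out_dim, rank)   # lora_B / lora_up

scale_down, scale_up = split_alpha_scale(alpha, rank)
baked = (up * scale_up) @ (down * scale_down)  # what the converter stores
reference = (alpha / rank) * (up @ down)       # the original forward-pass scaling
assert torch.allclose(baked, reference, atol=1e-5)
print(scale_down, scale_up)  # 0.5 0.5
```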
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)

             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up

                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
...
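For orientation, the key-renaming pattern these hunks apply, shown for one self-attention projection. The helper below is purely illustrative (it is not a function in the library) and only covers the A/B weight keys with `lora_down`/`lora_up` naming, not the `alpha` folding or `diff_b` biases:

```python
def rename_self_attn_key(key: str) -> str:
    # "blocks.{i}.self_attn.{o}.lora_down.weight" -> "blocks.{i}.attn1.{c}.lora_A.weight"
    proj_map = {"q": "to_q", "k": "to_k", "v": "to_v", "o": "to_out.0"}
    kind_map = {"lora_down": "lora_A", "lora_up": "lora_B"}
    _, i, _, o, kind, _ = key.split(".")
    return f"blocks.{i}.attn1.{proj_map[o]}.{kind_map[kind]}.weight"

print(rename_self_attn_key("blocks.0.self_attn.q.lora_down.weight"))
# blocks.0.attn1.to_q.lora_A.weight
print(rename_self_attn_key("blocks.3.self_attn.o.lora_up.weight"))
# blocks.3.attn1.to_out.0.lora_B.weight
```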
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
     """

-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME

     @classmethod
@@ -5270,9 +5270,29 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute 'transformer_2'. "
+                    "Note that Wan 2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
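Continuing the usage sketch from the docs hunk above, a quick way to confirm where each adapter landed; the adapter names and the printed output are illustrative, and `get_list_adapters` reports loaded adapters per component:

```python
# After loading "lightx2v" normally and "lightx2v_2" with load_into_transformer_2=True:
print(pipe.get_list_adapters())
# Illustrative output:
# {"transformer": ["lightx2v"], "transformer_2": ["lightx2v_2"]}
```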
@@ -5668,9 +5688,29 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute 'transformer_2'. "
+                    "Note that Wan 2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
...