Unverified Commit 425a715e authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Fix Wan AccVideo/CausVid fuse_lora (#11856)

* fix

* actually, better fix

* empty commit; trigger tests again

* mark wanvace test as flaky
parent 25279175
...@@ -1825,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): ...@@ -1825,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
has_time_projection_weight = any(
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
)
diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))] for key in list(original_state_dict.keys()):
if diff_keys: if key.endswith((".diff", ".diff_b")) and "norm" in key:
for diff_k in diff_keys: # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
param = original_state_dict[diff_k] # in future if needed and they are not zeroed.
# The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3, original_state_dict.pop(key)
# and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
# to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
# is okay to ignore because they do not affect the model output in a significant manner. if "time_projection" in key and not has_time_projection_weight:
threshold = 1.6e-2 # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
absdiff = param.abs().max() - param.abs().min() # our lora config adds the time proj lora layers, but we don't have the weights for them.
all_zero = torch.all(param == 0).item() # CausVid lora has the weight keys and the bias keys.
all_absdiff_lower_than_threshold = absdiff < threshold original_state_dict.pop(key)
if all_zero or all_absdiff_lower_than_threshold:
logger.debug(
f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
)
original_state_dict.pop(diff_k)
# For the `diff_b` keys, we treat them as lora_bias. # For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
......
...@@ -28,6 +28,7 @@ from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACE ...@@ -28,6 +28,7 @@ from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACE
from diffusers.utils.import_utils import is_peft_available from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
floats_tensor, floats_tensor,
is_flaky,
require_peft_backend, require_peft_backend,
require_peft_version_greater, require_peft_version_greater,
skip_mps, skip_mps,
...@@ -215,3 +216,7 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -215,3 +216,7 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.", "Lora outputs should match.",
) )
@is_flaky
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment