Unverified commit e52ceae3 authored by Aryan, committed by GitHub

Support Wan AccVideo lora (#11704)

* update

* make style

* Update src/diffusers/loaders/lora_conversion_utils.py

* add note explaining threshold
parent 62cbde8d
@@ -1605,9 +1605,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
+            # and 2 of them are about 1.6e-2 [the case with the AccVideo lora]). The low magnitudes mostly correspond
+            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
+            # is okay to ignore them because they do not affect the model output in a significant manner.
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)
 
     # For the `diff_b` keys, we treat them as lora_bias.
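For readers unfamiliar with the heuristic above, it reduces to the following standalone check: a .diff tensor is dropped when it is all zeros or when the spread of its absolute values stays below the hardcoded threshold. This is an illustrative sketch, not diffusers code; prune_diff_keys and the toy tensor are hypothetical names for the example.

import torch

def prune_diff_keys(state_dict, threshold=1.6e-2):
    # Hypothetical helper mirroring the pruning logic in the hunk above.
    diff_keys = [k for k in state_dict if k.endswith(".diff")]
    for k in diff_keys:
        param = state_dict[k]
        # Spread of absolute values; near-constant tiny tensors
        # (e.g. norm-layer offsets) fall below the threshold.
        absdiff = param.abs().max() - param.abs().min()
        if torch.all(param == 0).item() or absdiff < threshold:
            state_dict.pop(k)
    return state_dict

state_dict = {"blocks.0.norm.diff": torch.full((8,), 1e-5)}
prune_diff_keys(state_dict)
assert "blocks.0.norm.diff" not in state_dict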
@@ -1655,12 +1664,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_up_key}.weight"
-            )
+            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
             if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                 converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                     f"blocks.{i}.{o}.diff_b"
@@ -1669,12 +1682,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1726,20 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                     original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
                 )
 
+        for img_ours, img_theirs in [
+            ("ff.net.0.proj", "img_emb.proj.1"),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
...
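With these conversions in place, an AccVideo-style Wan LoRA in the original key layout should load through the usual diffusers API. A hedged usage sketch; the base model id is the public Wan 2.1 diffusers checkpoint, while the LoRA repository id and weight_name are placeholders for whichever AccVideo checkpoint you actually use.

import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16
)
# Placeholder repo/file: point these at the actual AccVideo LoRA weights.
pipe.load_lora_weights(
    "path/to/accvideo-lora-repo", weight_name="accvideo_lora.safetensors"
)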