"docs/ZH_CN/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "d0ec1aaae084bc83dc5f08f387e24bb231107a39"
Commit 3c3aa562 authored by wangshankun's avatar wangshankun
Browse files

LoRA: support multiple key prefixes (empty prefix and "diffusion_model.")

parent 0fcc5842
...@@ -53,29 +53,33 @@ class WanLoraWrapper: ...@@ -53,29 +53,33 @@ class WanLoraWrapper:
def _apply_lora_weights(self, weight_dict, lora_weights, alpha): def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
lora_pairs = {} lora_pairs = {}
lora_diffs = {} lora_diffs = {}
prefix = "diffusion_model."
def try_lora_pair(key, suffix_a, suffix_b, target_suffix): def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix):
if key.endswith(suffix_a): if key.endswith(suffix_a):
base_name = key[len(prefix) :].replace(suffix_a, target_suffix) base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
pair_key = key.replace(suffix_a, suffix_b) pair_key = key.replace(suffix_a, suffix_b)
if pair_key in lora_weights: if pair_key in lora_weights:
lora_pairs[base_name] = (key, pair_key) lora_pairs[base_name] = (key, pair_key)
def try_lora_diff(key, suffix, target_suffix): def try_lora_diff(key, prefix, suffix, target_suffix):
if key.endswith(suffix): if key.endswith(suffix):
base_name = key[len(prefix) :].replace(suffix, target_suffix) base_name = key[len(prefix) :].replace(suffix, target_suffix)
lora_diffs[base_name] = key lora_diffs[base_name] = key
for key in lora_weights.keys(): prefixs = [
if not key.startswith(prefix): "", # empty prefix
continue "diffusion_model.",
]
try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight") for prefix in prefixs:
try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight") for key in lora_weights.keys():
try_lora_diff(key, "diff", "weight") if not key.startswith(prefix):
try_lora_diff(key, "diff_b", "bias") continue
try_lora_diff(key, "diff_m", "modulation")
try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight")
try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight")
try_lora_diff(key, prefix, "diff", "weight")
try_lora_diff(key, prefix, "diff_b", "bias")
try_lora_diff(key, prefix, "diff_m", "modulation")
applied_count = 0 applied_count = 0
for name, param in weight_dict.items(): for name, param in weight_dict.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment