"vscode:/vscode.git/clone" did not exist on "ddb8577d8ca1fe9949a94f4a0d2b3328cb2a471b"
Commit 2d823f25 authored by Zhuguanyu Wu, committed by GitHub

Fix bugs in LoRA tools (#89)

* update LoRA keys

* update LoRA extractor/merger tools

* rename LoRA config and script files

* fix bugs in LoRA tools
parent 861b2e7d
@@ -13,6 +13,6 @@
     "cpu_offload": false,
     "denoising_step_list": [999, 750, 500, 250],
     "lora_path": [
-        "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
+        "Wan2.1-I2V-14B-480P/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors"
     ]
 }
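For context, a minimal sketch of how a lora_path entry like the one above is typically loaded. The safetensors call is a real API; the surrounding config handling is illustrative, not this repo's code:

    from safetensors.torch import load_file

    # Hypothetical config dict mirroring the JSON above
    config = {
        "lora_path": [
            "Wan2.1-I2V-14B-480P/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors"
        ]
    }
    for path in config["lora_path"]:
        lora_weights = load_file(path)  # dict mapping key names to torch.Tensor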
@@ -89,16 +89,18 @@ class WanLoraWrapper:
                 name_lora_A, name_lora_B = lora_pairs[name]
                 lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                 lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
-                param += torch.matmul(lora_B, lora_A) * alpha
-                applied_count += 1
+                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
+                    param += torch.matmul(lora_B, lora_A) * alpha
+                    applied_count += 1
             elif name in lora_diffs:
                 if name not in self.override_dict:
                     self.override_dict[name] = param.clone().cpu()
                 name_diff = lora_diffs[name]
                 lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-                param += lora_diff * alpha
-                applied_count += 1
+                if param.shape == lora_diff.shape:
+                    param += lora_diff * alpha
+                    applied_count += 1
         logger.info(f"Applied {applied_count} LoRA weight adjustments")
         if applied_count == 0:
......
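The fix above wraps each in-place update in a shape check, so a LoRA tensor whose low-rank product does not match the base parameter (e.g. a LoRA extracted from a different model variant) is skipped instead of raising a runtime error. A minimal sketch of the guarded update; the helper name is illustrative, not from the repo:

    import torch

    def apply_lora_delta(param: torch.Tensor, lora_A: torch.Tensor,
                         lora_B: torch.Tensor, alpha: float) -> bool:
        # lora_B: (out, r), lora_A: (r, in); their product must match param's shape
        if param.shape == (lora_B.shape[0], lora_A.shape[1]):
            param += torch.matmul(lora_B, lora_A) * alpha
            return True   # counted toward applied_count
        return False      # shape mismatch: base weight left untouched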
@@ -33,7 +33,7 @@ python -m lightx2v.infer \
     --model_cls wan2.1_distill \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg_lora_rank32.json \
+    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg_lora.json \
     --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
     --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
     --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
......
@@ -32,7 +32,7 @@ python -m lightx2v.infer \
     --model_cls wan2.1_distill \
     --task t2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora_rank32.json \
+    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \
     --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
     --use_prompt_enhancer \
     --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
......
@@ -274,8 +274,8 @@ def _decompose_to_lora(diff: torch.Tensor, key: str, rank: int) -> Dict[str, torch.Tensor]:
     # Generate LoRA weight key names
     base_key = key.replace(".weight", "")
-    lora_up_key = "diffusion_model." + f"{base_key}.lora_up"
-    lora_down_key = "diffusion_model." + f"{base_key}.lora_down"
+    lora_up_key = "diffusion_model." + f"{base_key}.lora_up.weight"
+    lora_down_key = "diffusion_model." + f"{base_key}.lora_down.weight"

     # Return the decomposed weights
     lora_weights = {lora_up_key: lora_up, lora_down_key: lora_down}
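The extractor now emits keys ending in .lora_up.weight / .lora_down.weight, matching the convention the merger (further below) expects. For reference, a hedged sketch of the rank-r decomposition such a function typically performs, assuming a standard truncated SVD; the repo's actual _decompose_to_lora may differ in details:

    import torch

    def decompose_to_lora(diff: torch.Tensor, key: str, rank: int):
        # Truncated SVD: diff ~= U[:, :r] @ diag(S[:r]) @ Vh[:r, :]
        U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
        lora_up = (U[:, :rank] * S[:rank]).contiguous()    # shape (out, r)
        lora_down = Vh[:rank, :].contiguous()              # shape (r, in)
        base_key = key.replace(".weight", "")
        return {
            f"diffusion_model.{base_key}.lora_up.weight": lora_up,
            f"diffusion_model.{base_key}.lora_down.weight": lora_down,
        }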
@@ -355,6 +355,9 @@ def extract_lora_from_diff(source_weights: Dict[str, torch.Tensor], target_weights
         if diff_only or is_1d or param_count < 1000000:
             # Save diff directly
             lora_key = _generate_lora_diff_key(key)
+            if lora_key == "skip":
+                skipped_count += 1
+                continue
             lora_weights[lora_key] = diff
             diff_count += 1
@@ -401,15 +404,16 @@ def _generate_lora_diff_key(original_key: str) -> str:
     Returns:
         LoRA weight key name
     """
+    ret_key = "diffusion_model." + original_key
     if original_key.endswith(".weight"):
-        return original_key.replace(".weight", ".diff")
+        return ret_key.replace(".weight", ".diff")
     elif original_key.endswith(".bias"):
-        return original_key.replace(".bias", ".diff_b")
+        return ret_key.replace(".bias", ".diff_b")
     elif original_key.endswith(".modulation"):
-        return original_key.replace(".modulation", ".diff_m")
+        return ret_key.replace(".modulation", ".diff_m")
     else:
-        # If no matching suffix, directly add .diff
-        return "diffusion_model." + original_key + ".diff"
+        # If no matching suffix, skip
+        return "skip"

 def main():
......
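To illustrate the new key mapping, a few hypothetical inputs and outputs (module names invented for the example):

    _generate_lora_diff_key("blocks.0.ffn.0.weight")  # -> "diffusion_model.blocks.0.ffn.0.diff"
    _generate_lora_diff_key("blocks.0.ffn.0.bias")    # -> "diffusion_model.blocks.0.ffn.0.diff_b"
    _generate_lora_diff_key("blocks.0.modulation")    # -> "diffusion_model.blocks.0.diff_m"
    _generate_lora_diff_key("some.buffer")            # -> "skip"; the caller now skips such keys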
@@ -226,13 +226,13 @@ def merge_lora_weights(source_weights: Dict[str, torch.Tensor], lora_weights: Dict[str, torch.Tensor]):
     diff_weights = {}
     for lora_key, lora_tensor in lora_weights.items():
-        if lora_key.endswith(".lora_up"):
-            base_key = lora_key.replace(".lora_up", "")
+        if lora_key.endswith(".lora_up.weight"):
+            base_key = lora_key.replace(".lora_up.weight", "")
             if base_key not in lora_pairs:
                 lora_pairs[base_key] = {}
             lora_pairs[base_key]["up"] = lora_tensor
-        elif lora_key.endswith(".lora_down"):
-            base_key = lora_key.replace(".lora_down", "")
+        elif lora_key.endswith(".lora_down.weight"):
+            base_key = lora_key.replace(".lora_down.weight", "")
             if base_key not in lora_pairs:
                 lora_pairs[base_key] = {}
             lora_pairs[base_key]["down"] = lora_tensor
......
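Once up/down pairs are collected, merging follows the usual W' = W + alpha * (up @ down) rule. A minimal sketch under that assumption; the helper name and prefix handling are illustrative, not the repo's exact code:

    from typing import Dict
    import torch

    def merge_pairs(source: Dict[str, torch.Tensor],
                    lora_pairs: Dict[str, Dict[str, torch.Tensor]],
                    alpha: float = 1.0) -> Dict[str, torch.Tensor]:
        merged = dict(source)
        for base_key, pair in lora_pairs.items():
            if "up" not in pair or "down" not in pair:
                continue  # unpaired tensors are ignored
            # strip the "diffusion_model." prefix to find the base weight
            weight_key = base_key.removeprefix("diffusion_model.") + ".weight"
            if weight_key in merged:
                delta = torch.matmul(pair["up"], pair["down"]) * alpha
                if merged[weight_key].shape == delta.shape:
                    merged[weight_key] = merged[weight_key] + delta
        return merged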