Commit 2d823f25 authored by Zhuguanyu Wu, committed by GitHub

bug fixed for lora tools (#89)

* update lora keys

* update lora extractor/merger tools

* rename lora config and script files

* bug fixed for lora tools
parent 861b2e7d
@@ -13,6 +13,6 @@
     "cpu_offload": false,
     "denoising_step_list": [999, 750, 500, 250],
     "lora_path": [
-        "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
+        "Wan2.1-I2V-14B-480P/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors"
     ]
 }
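For context, the checkpoint listed here is an ordinary safetensors archive, and the path appears to be resolved relative to the model directory (an assumption), so the swap from the rank-32 T2V file to the rank-64 I2V file is purely a config change. A minimal sketch of sanity-checking such a config, using generic json/safetensors calls rather than lightx2v's own loader (the config filename is assumed to be the renamed file from this commit):

    import json
    from safetensors.torch import load_file

    # Read the distill config and inspect the LoRA checkpoint(s) it lists.
    # Resolving the relative path against the model root is an assumption
    # about the local layout; adjust as needed.
    with open("configs/distill/wan_i2v_distill_4step_cfg_lora.json") as f:
        cfg = json.load(f)

    for rel_path in cfg["lora_path"]:
        state = load_file(rel_path)  # dict: tensor name -> torch.Tensor
        print(rel_path, "->", len(state), "tensors")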
@@ -89,16 +89,18 @@ class WanLoraWrapper:
                 name_lora_A, name_lora_B = lora_pairs[name]
                 lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                 lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
-                param += torch.matmul(lora_B, lora_A) * alpha
-                applied_count += 1
+                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
+                    param += torch.matmul(lora_B, lora_A) * alpha
+                    applied_count += 1
             elif name in lora_diffs:
                 if name not in self.override_dict:
                     self.override_dict[name] = param.clone().cpu()
                 name_diff = lora_diffs[name]
                 lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-                param += lora_diff * alpha
-                applied_count += 1
+                if param.shape == lora_diff.shape:
+                    param += lora_diff * alpha
+                    applied_count += 1
         logger.info(f"Applied {applied_count} LoRA weight adjustments")
         if applied_count == 0:
...
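The new shape guard prevents a hard RuntimeError when a LoRA's factors don't match the target parameter, e.g. a checkpoint exported for a different model variant or rank layout; mismatched entries are now skipped and simply not counted. A minimal standalone sketch of the same check outside the wrapper class (the function name and toy shapes are illustrative, not from the repo):

    import torch

    def apply_lora_pair(param: torch.Tensor, lora_A: torch.Tensor,
                        lora_B: torch.Tensor, alpha: float) -> bool:
        # B is (out_features, rank) and A is (rank, in_features), so B @ A
        # must match the target weight's (out_features, in_features) exactly.
        if param.shape != (lora_B.shape[0], lora_A.shape[1]):
            return False  # skip instead of crashing on a mismatched checkpoint
        param += torch.matmul(lora_B, lora_A) * alpha
        return True

    # A 128x256 weight accepts a rank-8 pair with matching dimensions...
    w = torch.zeros(128, 256)
    assert apply_lora_pair(w, torch.randn(8, 256), torch.randn(128, 8), alpha=1.0)
    # ...but a pair built for a 128x512 weight is skipped rather than fatal.
    assert not apply_lora_pair(w, torch.randn(8, 512), torch.randn(128, 8), alpha=1.0)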
@@ -33,7 +33,7 @@ python -m lightx2v.infer \
     --model_cls wan2.1_distill \
     --task i2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg_lora_rank32.json \
+    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg_lora.json \
     --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
     --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
     --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
...
@@ -32,7 +32,7 @@ python -m lightx2v.infer \
     --model_cls wan2.1_distill \
     --task t2v \
     --model_path $model_path \
-    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora_rank32.json \
+    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora.json \
     --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
     --use_prompt_enhancer \
     --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
...
@@ -274,8 +274,8 @@ def _decompose_to_lora(diff: torch.Tensor, key: str, rank: int) -> Dict[str, tor
     # Generate LoRA weight key names
     base_key = key.replace(".weight", "")
-    lora_up_key = "diffusion_model." + f"{base_key}.lora_up"
-    lora_down_key = "diffusion_model." + f"{base_key}.lora_down"
+    lora_up_key = "diffusion_model." + f"{base_key}.lora_up.weight"
+    lora_down_key = "diffusion_model." + f"{base_key}.lora_down.weight"

     # Return the decomposed weights
     lora_weights = {lora_up_key: lora_up, lora_down_key: lora_down}
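With the `.weight` suffix restored, the extractor's output keys line up with what the wrapper above (and standard LoRA loaders) look up. A sketch of the surrounding decomposition under that naming; the truncated-SVD split shown is a common way such extractors factor a weight diff, assumed here rather than copied from the tool, and the key `blocks.0.self_attn.q` is made up:

    import torch

    def decompose_diff(diff: torch.Tensor, base_key: str, rank: int) -> dict:
        # Factor the full-rank diff into lora_up @ lora_down via truncated SVD,
        # splitting the singular values evenly between the two factors.
        u, s, vh = torch.linalg.svd(diff.float(), full_matrices=False)
        lora_up = u[:, :rank] * s[:rank].sqrt()                  # (out, rank)
        lora_down = s[:rank].sqrt().unsqueeze(1) * vh[:rank, :]  # (rank, in)
        return {
            f"diffusion_model.{base_key}.lora_up.weight": lora_up,
            f"diffusion_model.{base_key}.lora_down.weight": lora_down,
        }

    weights = decompose_diff(torch.randn(64, 32), "blocks.0.self_attn.q", rank=4)
    print(list(weights))  # keys end in .lora_up.weight / .lora_down.weight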
@@ -355,6 +355,9 @@ def extract_lora_from_diff(source_weights: Dict[str, torch.Tensor], target_weigh
         if diff_only or is_1d or param_count < 1000000:
             # Save diff directly
             lora_key = _generate_lora_diff_key(key)
+            if lora_key == "skip":
+                skipped_count += 1
+                continue
             lora_weights[lora_key] = diff
             diff_count += 1
@@ -401,15 +404,16 @@ def _generate_lora_diff_key(original_key: str) -> str:
     Returns:
         LoRA weight key name
     """
+    ret_key = "diffusion_model." + original_key
     if original_key.endswith(".weight"):
-        return original_key.replace(".weight", ".diff")
+        return ret_key.replace(".weight", ".diff")
     elif original_key.endswith(".bias"):
-        return original_key.replace(".bias", ".diff_b")
+        return ret_key.replace(".bias", ".diff_b")
     elif original_key.endswith(".modulation"):
-        return original_key.replace(".modulation", ".diff_m")
+        return ret_key.replace(".modulation", ".diff_m")
     else:
-        # If no matching suffix, directly add .diff
-        return "diffusion_model." + original_key + ".diff"
+        # If no matching suffix, skip
+        return "skip"


 def main():
...
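Two behavior changes land in `_generate_lora_diff_key`: the `diffusion_model.` prefix now applies to every branch (previously only the fallback had it, so `.diff` / `.diff_b` / `.diff_m` keys came out unprefixed), and an unrecognized suffix now returns a `"skip"` sentinel that the caller above counts and drops instead of emitting a malformed key. A self-contained restatement with example mappings (the sample keys are invented):

    def generate_lora_diff_key(original_key: str) -> str:
        # Mirrors the fixed mapping: every emitted key carries the
        # "diffusion_model." prefix; unknown suffixes are skipped.
        ret_key = "diffusion_model." + original_key
        if original_key.endswith(".weight"):
            return ret_key.replace(".weight", ".diff")
        if original_key.endswith(".bias"):
            return ret_key.replace(".bias", ".diff_b")
        if original_key.endswith(".modulation"):
            return ret_key.replace(".modulation", ".diff_m")
        return "skip"

    assert generate_lora_diff_key("blocks.0.ffn.0.weight") == "diffusion_model.blocks.0.ffn.0.diff"
    assert generate_lora_diff_key("blocks.0.ffn.0.bias") == "diffusion_model.blocks.0.ffn.0.diff_b"
    assert generate_lora_diff_key("head.running_mean") == "skip"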
@@ -226,13 +226,13 @@ def merge_lora_weights(source_weights: Dict[str, torch.Tensor], lora_weights: Di
     diff_weights = {}

     for lora_key, lora_tensor in lora_weights.items():
-        if lora_key.endswith(".lora_up"):
-            base_key = lora_key.replace(".lora_up", "")
+        if lora_key.endswith(".lora_up.weight"):
+            base_key = lora_key.replace(".lora_up.weight", "")
             if base_key not in lora_pairs:
                 lora_pairs[base_key] = {}
             lora_pairs[base_key]["up"] = lora_tensor
-        elif lora_key.endswith(".lora_down"):
-            base_key = lora_key.replace(".lora_down", "")
+        elif lora_key.endswith(".lora_down.weight"):
+            base_key = lora_key.replace(".lora_down.weight", "")
             if base_key not in lora_pairs:
                 lora_pairs[base_key] = {}
             lora_pairs[base_key]["down"] = lora_tensor
...
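The merger's suffix matching moves to the same `.lora_up.weight` / `.lora_down.weight` convention the extractor now writes; without this, every pair would fall through unmatched and nothing would merge. A sketch of the pairing-and-fold step under that convention (key names and shapes are illustrative, and `merge_pair` is a hypothetical helper, not the tool's API):

    import torch

    def merge_pair(base: torch.Tensor, up: torch.Tensor,
                   down: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Fold a paired LoRA back into the base weight: W' = W + scale * up @ down
        return base + scale * (up @ down)

    lora_weights = {
        "diffusion_model.blocks.0.q.lora_up.weight": torch.randn(64, 4),
        "diffusion_model.blocks.0.q.lora_down.weight": torch.randn(4, 32),
    }

    # Group up/down tensors by their shared base key, as the fixed loop does.
    pairs = {}
    for key, tensor in lora_weights.items():
        for suffix, slot in ((".lora_up.weight", "up"), (".lora_down.weight", "down")):
            if key.endswith(suffix):
                pairs.setdefault(key[: -len(suffix)], {})[slot] = tensor

    base = torch.zeros(64, 32)
    for base_key, pair in pairs.items():
        print(base_key, merge_pair(base, pair["up"], pair["down"]).shape)  # (64, 32)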