Commit 83c12f2b authored by PengGao, committed by GitHub

feat: refactor LoRA handling and add run script (#16)

parent 4eec372d
@@ -8,20 +8,20 @@ import gc
 class WanLoraWrapper:
     def __init__(self, wan_model):
         self.model = wan_model
-        self.lora_dict = {}
-        self.override_dict = {}
+        self.lora_metadata = {}
+        self.override_dict = {}  # On CPU
 
     def load_lora(self, lora_path, lora_name=None):
         if lora_name is None:
             lora_name = os.path.basename(lora_path).split(".")[0]
 
-        if lora_name in self.lora_dict:
+        if lora_name in self.lora_metadata:
             logger.info(f"LoRA {lora_name} already loaded, skipping...")
             return lora_name
 
-        lora_weights = self._load_lora_file(lora_path)
-        self.lora_dict[lora_name] = lora_weights
+        self.lora_metadata[lora_name] = {"path": lora_path}
+        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")
         return lora_name
 
     def _load_lora_file(self, file_path):
@@ -36,7 +36,7 @@ class WanLoraWrapper:
         return tensor_dict
 
     def apply_lora(self, lora_name, alpha=1.0):
-        if lora_name not in self.lora_dict:
+        if lora_name not in self.lora_metadata:
             logger.info(f"LoRA {lora_name} not found. Please load it first.")
 
         if hasattr(self.model, "current_lora") and self.model.current_lora:
@@ -46,19 +46,16 @@ class WanLoraWrapper:
             logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
             return False
 
+        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
         weight_dict = self.model.original_weight_dict
-        lora_weights = self.lora_dict[lora_name]
         self._apply_lora_weights(weight_dict, lora_weights, alpha)
-        # Reload the weights
-        self.model.pre_weight.load_weights(weight_dict)
-        self.model.post_weight.load_weights(weight_dict)
-        self.model.transformer_weights.load_weights(weight_dict)
+        self.model._init_weights(weight_dict)
 
         self.model.current_lora = lora_name
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
         return True
 
+    @torch.no_grad()
     def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
         lora_pairs = {}
         prefix = "diffusion_model."
@@ -73,6 +70,9 @@ class WanLoraWrapper:
         applied_count = 0
         for name, param in weight_dict.items():
             if name in lora_pairs:
+                if name not in self.override_dict:
+                    self.override_dict[name] = param.clone().cpu()
                 name_lora_A, name_lora_B = lora_pairs[name]
                 lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                 lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
@@ -85,6 +85,7 @@ class WanLoraWrapper:
                 "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
             )
 
+    @torch.no_grad()
     def remove_lora(self):
         if not self.model.current_lora:
             logger.info("No LoRA currently applied")
@@ -98,19 +99,18 @@ class WanLoraWrapper:
         logger.info(f"LoRA {self.model.current_lora} removed, restored {restored_count} weights")
 
-        self.model.pre_weight.load_weights(self.model.original_weight_dict)
-        self.model.post_weight.load_weights(self.model.original_weight_dict)
-        self.model.transformer_weights.load_weights(self.model.original_weight_dict)
+        self.model._init_weights(self.model.original_weight_dict)
 
-        if self.model.current_lora and self.model.current_lora in self.lora_dict:
-            del self.lora_dict[self.model.current_lora]
-        self.override_dict = {}
         torch.cuda.empty_cache()
         gc.collect()
 
+        if self.model.current_lora and self.model.current_lora in self.lora_metadata:
+            del self.lora_metadata[self.model.current_lora]
+        self.override_dict = {}
+        self.model.current_lora = None
 
     def list_loaded_loras(self):
-        return list(self.lora_dict.keys())
+        return list(self.lora_metadata.keys())
 
     def get_current_lora(self):
         return self.model.current_lora
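
In short, after this refactor load_lora() only records the safetensors path in lora_metadata; apply_lora() reads the file on demand, backs up each touched weight to CPU in override_dict, and folds the low-rank update into the live weight dict; remove_lora() restores the originals and rebuilds the weights via _init_weights(). The snippet below is a standalone sketch of that merge/restore mechanic using plain torch tensors only: the tensor names and shapes are made up, and the update formula alpha * (B @ A) is the standard LoRA merge, assumed here because the actual merge line falls outside the hunks shown above.

import torch

alpha = 1.0
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical weight and LoRA tensors; key names follow the expected
# 'diffusion_model.<layer_name>.lora_A/B.weight' convention from the warning above.
weight_dict = {"blocks.0.attn.q.weight": torch.randn(8, 8, device=device)}
lora_weights = {
    "diffusion_model.blocks.0.attn.q.lora_A.weight": torch.randn(4, 8),
    "diffusion_model.blocks.0.attn.q.lora_B.weight": torch.randn(8, 4),
}

override_dict = {}  # CPU backups, mirroring self.override_dict
name = "blocks.0.attn.q.weight"
param = weight_dict[name]

# apply: back up once, then merge the low-rank update in place
if name not in override_dict:
    override_dict[name] = param.clone().cpu()
lora_A = lora_weights["diffusion_model.blocks.0.attn.q.lora_A.weight"].to(param.device, param.dtype)
lora_B = lora_weights["diffusion_model.blocks.0.attn.q.lora_B.weight"].to(param.device, param.dtype)
param += alpha * (lora_B @ lora_A)  # assumed standard LoRA merge

# remove: restore the original tensor from the CPU backup
param.copy_(override_dict[name].to(param.device, param.dtype))

The run script added by this commit follows.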
#!/bin/bash
# set paths first
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
lightx2v_path="$(dirname "$script_dir")"
model_path=/mnt/aigc/shared_data/cache/huggingface/hub/Wan2.1-I2V-14B-480P
config_path=$model_path/config.json
lora_path=/mnt/aigc/shared_data/wan_quant/lora/toy_zoe_epoch_324.safetensors
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${config_path}" ]; then
echo "Error: config_path is not set. Please set this variable first."
exit 1
fi
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
python -m lightx2v \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--prompt "画面中的物体轻轻向上跃起,变成了外貌相似的毛绒玩具。毛绒玩具有着一双眼睛,它的颜色和之前的一样。然后,它开始跳跃起来。背景保持一致,气氛显得格外俏皮。" \
--infer_steps 40 \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn3 \
--seed 42 \
--sample_neg_promp "画面过曝,模糊,文字,字幕" \
--config_path $config_path \
--save_video_path ./output_lightx2v_wan_i2v.mp4 \
--sample_guide_scale 5 \
--sample_shift 5 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--lora_path ${lora_path} \
--feature_caching Tea \
--mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --mm_config '{"mm_type": "Default", "weight_auto_quant": true}' \
# --use_ret_steps \