Commit 1f50bcb2 authored by GoatWu

update lora adapter

parent 9d551b87
{
    "infer_steps": 4,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "denoising_step_list": [999, 750, 500, 250],
    "lora_path": [
        "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
    ]
}
{
    "infer_steps": 4,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "denoising_step_list": [999, 750, 500, 250],
    "lora_path": [
        "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
    ]
}
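The two configs (presumably the t2v and i2v rank-32 LoRA variants referenced by the scripts below) differ only in text_len, sample_guide_scale, and sample_shift; both point the new list-valued lora_path at the step-distill LoRA. A minimal sketch of consuming such a config; the string-to-list normalization is an assumption for legacy single-path configs, not confirmed repo behavior:

import json

# Config path taken from the t2v script later in this commit.
with open("configs/distill/wan_t2v_distill_4step_cfg_lora_rank32.json") as f:
    config = json.load(f)

# lora_path is a list in the new configs; tolerate a legacy bare string.
lora_paths = config.get("lora_path", [])
if isinstance(lora_paths, str):
    lora_paths = [lora_paths]
for path in lora_paths:
    print(f"will apply LoRA: {path}")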
......@@ -31,7 +31,6 @@ class WanCausVidModel(WanModel):
        use_bfloat16 = GET_DTYPE() == "BF16"
        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
        if not os.path.exists(ckpt_path):
            # File does not exist; fall back to the parent class's _load_ckpt method
            return super()._load_ckpt(use_bf16, skip_bf16)
        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
......
......@@ -23,7 +23,6 @@ class WanDistillModel(WanModel):
    def _load_ckpt(self, use_bf16, skip_bf16):
        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
        if not os.path.exists(ckpt_path):
            # File does not exist; fall back to the parent class's _load_ckpt method
            return super()._load_ckpt(use_bf16, skip_bf16)
        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
......
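Both model subclasses use the same fallback: prefer the specialized checkpoint, otherwise defer to the parent loader. A self-contained sketch of the pattern with illustrative class and file names (not the repo's API):

import os
import torch

class BaseModel:
    def __init__(self, model_path):
        self.model_path = model_path

    def _load_ckpt(self):
        # Generic loader used when no specialized checkpoint is present.
        path = os.path.join(self.model_path, "model.pt")
        return torch.load(path, map_location="cpu", weights_only=True)

class DistillModel(BaseModel):
    def _load_ckpt(self):
        ckpt_path = os.path.join(self.model_path, "distill_model.pt")
        if not os.path.exists(ckpt_path):
            # Distilled weights missing; fall back to the parent loader.
            return super()._load_ckpt()
        return torch.load(ckpt_path, map_location="cpu", weights_only=True)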
......@@ -40,9 +40,6 @@ class WanLoraWrapper:
        if lora_name not in self.lora_metadata:
            logger.info(f"LoRA {lora_name} not found. Please load it first.")
        if hasattr(self.model, "current_lora") and self.model.current_lora:
            self.remove_lora()
        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False
......@@ -52,21 +49,35 @@ class WanLoraWrapper:
        self._apply_lora_weights(weight_dict, lora_weights, alpha)
        self.model._init_weights(weight_dict)
        self.model.current_lora = lora_name
        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        return True

    @torch.no_grad()
    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        lora_pairs = {}
        lora_diffs = {}
        prefix = "diffusion_model."

        def try_lora_pair(key, suffix_a, suffix_b, target_suffix):
            if key.endswith(suffix_a):
                base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
                pair_key = key.replace(suffix_a, suffix_b)
                if pair_key in lora_weights:
                    lora_pairs[base_name] = (key, pair_key)

        def try_lora_diff(key, suffix, target_suffix):
            if key.endswith(suffix):
                base_name = key[len(prefix) :].replace(suffix, target_suffix)
                lora_diffs[base_name] = key

        for key in lora_weights.keys():
            if key.endswith("lora_A.weight") and key.startswith(prefix):
                base_name = key[len(prefix) :].replace("lora_A.weight", "weight")
                b_key = key.replace("lora_A.weight", "lora_B.weight")
                if b_key in lora_weights:
                    lora_pairs[base_name] = (key, b_key)
            if not key.startswith(prefix):
                continue
            try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
            try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
            try_lora_diff(key, "diff", "weight")
            try_lora_diff(key, "diff_b", "bias")

        applied_count = 0
        for name, param in weight_dict.items():
......@@ -79,6 +90,14 @@ class WanLoraWrapper:
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
                param += torch.matmul(lora_B, lora_A) * alpha
                applied_count += 1
            elif name in lora_diffs:
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()
                name_diff = lora_diffs[name]
                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
                param += lora_diff
                applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
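The merge math is the standard LoRA update: a matched A/B (or down/up) pair contributes W += alpha * (B @ A), while the full-rank diff/diff_b entries are added directly to the weight or bias. A standalone sketch with hypothetical shapes:

import torch

# Hypothetical 64x64 linear layer with a rank-4 LoRA.
out_features, in_features, rank, alpha = 64, 64, 4, 1.0
W = torch.randn(out_features, in_features)
lora_A = torch.randn(rank, in_features)   # "lora_A.weight" / "lora_down.weight"
lora_B = torch.randn(out_features, rank)  # "lora_B.weight" / "lora_up.weight"

# Low-rank pair: W' = W + alpha * (B @ A)
W += torch.matmul(lora_B, lora_A) * alpha

# Full-rank "diff" entry: a delta added as-is (bias deltas use "diff_b").
diff = torch.randn(out_features, in_features)
W += diff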
......@@ -88,30 +107,22 @@ class WanLoraWrapper:
    @torch.no_grad()
    def remove_lora(self):
        if not self.model.current_lora:
            logger.info("No LoRA currently applied")
            return

        logger.info(f"Removing LoRA {self.model.current_lora}...")
        logger.info("Removing LoRA ...")

        restored_count = 0
        for k, v in self.override_dict.items():
            self.model.original_weight_dict[k] = v.to(self.model.device)
            restored_count += 1

        logger.info(f"LoRA {self.model.current_lora} removed, restored {restored_count} weights")
        logger.info(f"LoRA removed, restored {restored_count} weights")
        self.model._init_weights(self.model.original_weight_dict)
        torch.cuda.empty_cache()
        gc.collect()

        if self.model.current_lora and self.model.current_lora in self.lora_metadata:
            del self.lora_metadata[self.model.current_lora]
        self.lora_metadata = {}
        self.override_dict = {}
        self.model.current_lora = None

    def list_loaded_loras(self):
        return list(self.lora_metadata.keys())

    def get_current_lora(self):
        return self.model.current_lora
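Taken together, a typical wrapper lifecycle is load, apply, and later remove. A sketch assuming `model` is a WanModel already constructed by a runner, and that apply_lora takes the strength as its second argument, as the runner code below does:

from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper

wrapper = WanLoraWrapper(model)  # `model` assumed initialized elsewhere
lora_name = wrapper.load_lora(
    "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
)
wrapper.apply_lora(lora_name, 1.0)   # merges LoRA deltas into the weights in place
print(wrapper.list_loaded_loras())   # -> [lora_name]
wrapper.remove_lora()                # restores originals saved in override_dict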
......@@ -46,7 +46,6 @@ class WanModel:
        self._init_infer_class()
        self._init_weights()
        self._init_infer()
        self.current_lora = None

        if config["parallel_attn_type"]:
            if config["parallel_attn_type"] == "ulysses":
......
......@@ -12,6 +12,7 @@ from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistill
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.causvid_model import WanCausVidModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
......@@ -30,7 +31,20 @@ class WanCausVidRunner(WanRunner):
        self.num_fragments = self.model.config.num_fragments

    def load_transformer(self):
        return WanCausVidModel(self.config.model_path, self.config, self.init_device)
        if self.config.lora_path:
            model = WanModel(
                self.config.model_path,
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
            for lora_path in self.config.lora_path:
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
                logger.info(f"Loaded LoRA: {lora_name}")
        else:
            model = WanCausVidModel(self.config.model_path, self.config, self.init_device)
        return model

    def set_inputs(self, inputs):
        super().set_inputs(inputs)
......
......@@ -24,12 +24,19 @@ class WanDistillRunner(WanRunner):
        super().__init__(config)

    def load_transformer(self):
        model = WanDistillModel(self.config.model_path, self.config, self.init_device)
        if self.config.lora_path:
            model = WanModel(
                self.config.model_path,
                self.config,
                self.init_device,
            )
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")
            for lora_path in self.config.lora_path:
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
                logger.info(f"Loaded LoRA: {lora_name}")
        else:
            model = WanDistillModel(self.config.model_path, self.config, self.init_device)
        return model

    def init_scheduler(self):
......
......@@ -38,9 +38,10 @@ class WanRunner(DefaultRunner):
        if self.config.lora_path:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            logger.info(f"Loaded LoRA: {lora_name}")
            for lora_path in self.config.lora_path:
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, self.config.strength_model)
                logger.info(f"Loaded LoRA: {lora_name}")
        return model

    def load_image_encoder(self):
......
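All three runners now iterate over lora_path, so several LoRAs can be stacked by listing them in the config. A hedged config example: the second entry is hypothetical, and whether strength_model is set in the JSON or elsewhere in the config is an assumption.

{
    "lora_path": [
        "Wan2.1-T2V-14B/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors",
        "Wan2.1-T2V-14B/loras/your_second_lora.safetensors"
    ],
    "strength_model": 1.0
}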
......@@ -33,7 +33,7 @@ python -m lightx2v.infer \
    --model_cls wan2.1_distill \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill.json \
    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json \
    --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
    --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
......
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
python -m lightx2v.infer \
    --model_cls wan2.1_distill \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg_lora_rank32.json \
    --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
    --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_distill.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
lightx2v_path="/data/lightx2v-dev"
model_path="/data/lightx2v-dev/Wan2.1-T2V-14B/"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......@@ -32,7 +32,7 @@ python -m lightx2v.infer \
    --model_cls wan2.1_distill \
    --task t2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill.json \
    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --use_prompt_enhancer \
    --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
......
#!/bin/bash
# set paths first
lightx2v_path="/data/lightx2v-dev"
model_path="/data/lightx2v-dev/Wan2.1-T2V-14B/"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
    --model_cls wan2.1_distill \
    --task t2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_lora_rank32.json \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --use_prompt_enhancer \
    --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4