Commit b2147c40 authored by wangshankun

Apply: distill lora

parent d02b97a7
 {
-    "infer_steps": 20,
+    "infer_steps": 8,
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
     "attention_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": 5.0,
+    "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
-    "cpu_offload": false
+    "cpu_offload": false,
+    "feature_caching": "Tea",
+    "coefficients": [
+        [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
+        [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+    ],
+    "use_ret_steps": true,
+    "teacache_thresh": 0.12
 }
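These settings match the cfg/step-distilled LoRA wired in below (`Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors`): with guidance distilled into the weights, `infer_steps` drops from 20 to 8, `sample_guide_scale` drops from 5.0 to 1, and `enable_cfg` stays off. The new block enables TeaCache (`feature_caching: "Tea"`), which skips full transformer passes while consecutive steps change little; the `coefficients` rows are polynomial fits that rescale the raw relative-L1 change of the timestep-modulated input into an estimated output change (which row applies depends on `use_ret_steps`). A minimal sketch of how such coefficients are typically consumed; the function and argument names are hypothetical, not lightx2v's API:

```python
import numpy as np

def should_skip_step(prev_mod_input, cur_mod_input, coefficients, accumulated, thresh):
    """Hypothetical TeaCache-style skip test; returns (skip, new_accumulated)."""
    rescale = np.poly1d(coefficients)  # fit is given highest-degree term first
    # Relative L1 change of the timestep-embedding-modulated input.
    rel_l1 = np.mean(np.abs(cur_mod_input - prev_mod_input)) / np.mean(np.abs(prev_mod_input))
    accumulated += float(rescale(rel_l1))
    if accumulated < thresh:      # e.g. teacache_thresh = 0.12
        return True, accumulated  # small drift: reuse the cached residual
    return False, 0.0             # large drift: run the full forward pass, reset
```

With only 8 steps, each skipped pass covers a larger share of the denoising trajectory than at 20 steps, so the 0.12 threshold trades quality for speed more aggressively in this setup.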
@@ -50,6 +50,7 @@ async def main():
     parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
     parser.add_argument("--negative_prompt", type=str, default="")
+    parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
     parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
     parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
......
@@ -434,17 +434,10 @@ class CLIPModel:
 class WanVideoIPHandler:
-    def __init__(self,
-                 model_name,
-                 repo_or_path,
-                 require_grad=False,
-                 mode='eval',
-                 device='cuda',
-                 dtype=torch.float16):
+    def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
         # image_processor = CLIPImageProcessor.from_pretrained(
         #     repo_or_path, subfolder='image_processor')
-        ''' 720P-I2V-diffusers config is
+        """720P-I2V-diffusers config is
         "size": {
             "shortest_edge": 224
         }
@@ -455,14 +448,11 @@ class WanVideoIPHandler:
         }
         but Wan2.1 officially uses no_crop resize by default,
         so I don't use CLIPImageProcessor
-        '''
-        image_encoder = CLIPVisionModel.from_pretrained(
-            repo_or_path, subfolder='image_encoder', torch_dtype=dtype)
-        logger.info(
-            f'Using image encoder {model_name} from {repo_or_path}'
-        )
+        """
+        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, subfolder="image_encoder", torch_dtype=dtype)
+        logger.info(f"Using image encoder {model_name} from {repo_or_path}")
         image_encoder.requires_grad_(require_grad)
-        if mode == 'eval':
+        if mode == "eval":
             image_encoder.eval()
         else:
             image_encoder.train()
@@ -482,16 +472,10 @@ class WanVideoIPHandler:
         if img_tensor.ndim == 5:  # B C T H W
             # img_tensor = img_tensor[:, :, 0]
             img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
-        img_tensor = torch.clamp(
-            img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
-        img_tensor = F.interpolate(
-            img_tensor, size=self.size, mode='bicubic', align_corners=False)
+        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
+        img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
         img_tensor = self.normalize(img_tensor).to(self.dtype)
-        logger.info(
-            f'Image tensor shape after processing: {img_tensor}')
-        image_embeds = self.image_encoder(
-            pixel_values=img_tensor, output_hidden_states=True)
-        logger.info(
-            f'Image embeds : {image_embeds.hidden_states[-1]}')
-        return image_embeds.hidden_states[-1]
\ No newline at end of file
+        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
+        return image_embeds.hidden_states[-1]
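A side note on the docstring's preprocessing point above: `CLIPImageProcessor` with `shortest_edge: 224` resizes the short side and center-crops to 224x224, discarding the frame's edges, while this handler stretches the whole frame to the target size with bicubic interpolation, keeping all content at the cost of aspect ratio. A sketch of the difference (shapes are illustrative, not taken from the repo):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 480, 832)  # B C H W, one 480x832 frame

# CLIPImageProcessor-style: resize the shortest edge to 224, then center-crop.
h, w = 224, int(832 * 224 / 480)  # short side -> 224, aspect ratio preserved
resized = F.interpolate(x, size=(h, w), mode="bicubic", align_corners=False)
left = (w - 224) // 2
cropped = resized[:, :, :, left:left + 224]  # drops the left/right edges

# Wan2.1-style no_crop: stretch the full frame straight to 224x224.
stretched = F.interpolate(x, size=(224, 224), mode="bicubic", align_corners=False)
```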
@@ -84,7 +84,8 @@ class WanLoraWrapper:
             if name in lora_pairs:
                 if name not in self.override_dict:
                     self.override_dict[name] = param.clone().cpu()
+                # import pdb
+                # pdb.set_trace()
                 name_lora_A, name_lora_B = lora_pairs[name]
                 lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                 lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
......
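For context on the elided body: after loading `lora_A` and `lora_B`, the standard LoRA step is to fuse the low-rank product into the base weight in place, which is presumably what the hidden lines do. A minimal sketch under that assumption (the helper name is hypothetical):

```python
import torch

@torch.no_grad()
def merge_lora_pair(param, lora_A, lora_B, alpha=1.0):
    """Fuse one LoRA pair into a base weight in place: W <- W + alpha * (B @ A).

    lora_A is (rank, in_features) and lora_B is (out_features, rank), so the
    product matches param's (out_features, in_features) shape.
    """
    param.add_(alpha * (lora_B @ lora_A))
```

Cloning the original tensor into `override_dict` first, as the hunk above does, keeps the merge reversible: restoring the clone removes the LoRA without reloading the checkpoint.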
@@ -264,17 +264,20 @@ class WanAudioRunner(WanRunner):
     def load_transformer(self):
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
         if self.config.lora_path:
+            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
             lora_name = lora_wrapper.load_lora(self.config.lora_path)
             lora_wrapper.apply_lora(lora_name, self.config.strength_model)
             logger.info(f"Loaded LoRA: {lora_name}")
         return base_model

     def load_image_encoder(self):
-        image_encoder = WanVideoIPHandler(
-            "CLIPModel",
-            repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers",
-            require_grad=False,
-            mode='eval',
-            device=self.init_device,
-            dtype=torch.float16)
+        image_encoder = WanVideoIPHandler(
+            "CLIPModel", repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers", require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16
+        )
         return image_encoder
......
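The new assert in `load_transformer` guards an unsupported combination: fusing LoRA deltas into an already-quantized DiT would require dequantizing and requantizing every touched weight, so LoRA presumably only works when the model is unquantized or is quantized on the fly after loading (`weight_auto_quant`). A more readable restatement of the condition (hypothetical helper, not part of the codebase):

```python
def lora_merge_supported(config) -> bool:
    if not config.get("dit_quantized", False):
        return True  # full-precision weights can always absorb a LoRA merge
    # A pre-quantized DiT needs on-the-fly weight quantization, so the LoRA
    # can be fused into full-precision tensors before they are quantized.
    return config.mm_config.get("weight_auto_quant", False)
```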
@@ -3,7 +3,7 @@
 # set path first
 lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
 model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-Audio-14B-720P/"
+lora_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
     cuda_devices=0
@@ -37,4 +37,5 @@ python -m lightx2v.infer \
     --negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
     --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
     --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
-    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
+    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
+    --lora_path ${lora_path}