Commit 6ac3cee7 authored by Zhuguanyu Wu's avatar Zhuguanyu Wu Committed by GitHub
Browse files

Support dynamic CFG distillation (#100)

* support dynamic cfg for cfg_distill

* reformat files
parent 62d8881a
{
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"enable_dynamic_cfg": true,
"cfg_scale": 4.0,
"cpu_offload": false,
"denoising_step_list": [999, 750, 500, 250]
}
......@@ -3,6 +3,7 @@ import sys
import torch
import glob
import json
from safetensors import safe_open
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
......@@ -21,11 +22,20 @@ class WanDistillModel(WanModel):
super().__init__(model_path, config, device)
def _load_ckpt(self, use_bf16, skip_bf16):
ckpt_path = os.path.join(self.model_path, "distill_model.pt")
if not os.path.exists(ckpt_path):
return super()._load_ckpt(use_bf16, skip_bf16)
enable_dynamic_cfg = self.config.get("enable_dynamic_cfg", False)
ckpt_folder = "distill_cfg_models" if enable_dynamic_cfg else "distill_models"
safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.safetensors")
if os.path.exists(safetensors_path):
with safe_open(safetensors_path, framework="pt") as f:
weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
return weight_dict
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/distill_model.pt")
if os.path.exists(ckpt_path):
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
}
return weight_dict
return weight_dict
return super()._load_ckpt(use_bf16, skip_bf16)
import torch
from .utils import rope_params, sinusoidal_embedding_1d
from .utils import rope_params, sinusoidal_embedding_1d, guidance_scale_embedding
from lightx2v.utils.envs import *
......@@ -20,6 +20,8 @@ class WanPreInfer:
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
self.cfg_scale = config.get("cfg_scale", 4.0)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -60,6 +62,11 @@ class WanPreInfer:
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
if self.enable_dynamic_cfg:
s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(0.0, 8.0), target_range=1000.0, dtype=torch.float32).type_as(x)
cfg_embed = weights.cfg_cond_proj.apply(cfg_embed)
embed = embed + cfg_embed
if GET_DTYPE() != "BF16":
embed = weights.time_embedding_0.apply(embed.float())
else:
......
......@@ -170,3 +170,28 @@ def sinusoidal_embedding_1d(dim, position):
if GET_DTYPE() == "BF16":
x = x.to(torch.bfloat16)
return x
def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(0.0, 8.0), target_range=1000.0, dtype=torch.float32):
"""
Args:
timesteps: torch.Tensor: generate embedding vectors at these timesteps
embedding_dim: int: dimension of the embeddings to generate
dtype: data type of the generated embeddings
Returns:
embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
cfg_min, cfg_max = cfg_range
w = (w - cfg_min) / (cfg_max - cfg_min) # [0, 1]
w = w * target_range
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1).to(w.device))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
......@@ -56,3 +56,9 @@ class WanPreWeights(WeightModule):
"proj_4",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
)
if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
self.add_module(
"cfg_cond_proj",
MM_WEIGHT_REGISTER["Default"]("cfg_cond_proj.weight", "cfg_cond_proj.bias"),
)
......@@ -67,7 +67,7 @@ if __name__ == "__main__":
{
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt": "A cat walks on the grass, realistic style.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t01.mp4", # It is best to set it to an absolute path.
......@@ -75,7 +75,7 @@ if __name__ == "__main__":
{
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt": "A person is riding a bike. Realistic, Natural lighting, Casual.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4", # It is best to set it to an absolute path.
......@@ -83,7 +83,7 @@ if __name__ == "__main__":
{
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt": "A car turns a corner. Realistic, Natural lighting, Casual.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t03.mp4", # It is best to set it to an absolute path.
......@@ -91,7 +91,7 @@ if __name__ == "__main__":
{
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt": "An astronaut is flying in space, Van Gogh style. Dark, Mysterious.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t04.mp4", # It is best to set it to an absolute path.
......@@ -99,11 +99,19 @@ if __name__ == "__main__":
{
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"prompt": "A beautiful coastal beach in spring, waves gently lapping on the sand, the camera movement is Zoom In. Realistic, Natural lighting, Peaceful.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t05.mp4", # It is best to set it to an absolute path.
},
{
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t06.mp4", # It is best to set it to an absolute path.
},
]
logger.info(f"urls: {urls}")
......
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1_distill \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg_dynamic.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--use_prompt_enhancer \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_cfg_4.mp4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment