Commit 56af41eb authored by helloyongyang's avatar helloyongyang
Browse files

Big Refactor

parent 142f6872
......@@ -3,22 +3,22 @@ from ..scheduler import WanScheduler
class WanSchedulerTeaCaching(WanScheduler):
def __init__(self, args):
super().__init__(args)
def __init__(self, config):
super().__init__(config)
self.cnt = 0
self.num_steps = self.args.infer_steps * 2
self.teacache_thresh = self.args.teacache_thresh
self.num_steps = self.config.infer_steps * 2
self.teacache_thresh = self.config.teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_even = None
self.previous_e0_odd = None
self.previous_residual_even = None
self.previous_residual_odd = None
self.use_ret_steps = self.args.use_ret_steps
self.use_ret_steps = self.config.use_ret_steps
if self.args.task == "i2v":
if self.config.task == "i2v":
if self.use_ret_steps:
if self.args.target_width == 480 or self.args.target_height == 480:
if self.config.target_width == 480 or self.config.target_height == 480:
self.coefficients = [
2.57151496e05,
-3.54229917e04,
......@@ -26,7 +26,7 @@ class WanSchedulerTeaCaching(WanScheduler):
-1.35890334e01,
1.32517977e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
if self.config.target_width == 720 or self.config.target_height == 720:
self.coefficients = [
8.10705460e03,
2.13393892e03,
......@@ -35,9 +35,9 @@ class WanSchedulerTeaCaching(WanScheduler):
-4.17769401e-02,
]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
if self.args.target_width == 480 or self.args.target_height == 480:
if self.config.target_width == 480 or self.config.target_height == 480:
self.coefficients = [
-3.02331670e02,
2.23948934e02,
......@@ -45,7 +45,7 @@ class WanSchedulerTeaCaching(WanScheduler):
5.87348440e00,
-2.01973289e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
if self.config.target_width == 720 or self.config.target_height == 720:
self.coefficients = [
-114.36346466,
65.26524496,
......@@ -54,23 +54,23 @@ class WanSchedulerTeaCaching(WanScheduler):
-0.23412683,
]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
elif self.args.task == "t2v":
elif self.config.task == "t2v":
if self.use_ret_steps:
if "1.3B" in self.args.model_path:
if "1.3B" in self.config.model_path:
self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
if "14B" in self.args.model_path:
if "14B" in self.config.model_path:
self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
self.cutoff_steps = self.config.infer_steps * 2
else:
if "1.3B" in self.args.model_path:
if "1.3B" in self.config.model_path:
self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
if "14B" in self.args.model_path:
if "14B" in self.config.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
def clear(self):
if self.previous_e0_even is not None:
......
......@@ -6,25 +6,28 @@ from lightx2v.models.schedulers.scheduler import BaseScheduler
class WanScheduler(BaseScheduler):
def __init__(self, args):
super().__init__(args)
def __init__(self, config):
super().__init__(config)
self.device = torch.device("cuda")
self.infer_steps = self.args.infer_steps
self.target_video_length = self.args.target_video_length
self.sample_shift = self.args.sample_shift
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.shift = 1
self.num_train_timesteps = 1000
self.disable_corrector = []
self.solver_order = 2
self.noise_pred = None
def prepare(self, image_encoder_output):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.args.seed)
self.prepare_latents(self.args.target_shape, dtype=torch.float32)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if self.args.task in ["t2v"]:
self.seq_len = math.ceil((self.args.target_shape[2] * self.args.target_shape[3]) / (self.args.patch_size[1] * self.args.patch_size[2]) * self.args.target_shape[1])
elif self.args.task in ["i2v"]:
self.seq_len = ((self.args.target_video_length - 1) // self.args.vae_stride[0] + 1) * args.lat_h * args.lat_w // (args.patch_size[1] * args.patch_size[2])
if self.config.task in ["t2v"]:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
elif self.config.task in ["i2v"]:
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
......
......@@ -49,3 +49,5 @@ RMS_WEIGHT_REGISTER = Register()
LN_WEIGHT_REGISTER = Register()
CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register()
......@@ -3,20 +3,37 @@ import os
from easydict import EasyDict
def get_default_config():
    """Build the baseline configuration shared by every run.

    Returns a plain dict of default settings; callers overlay parsed CLI
    arguments and JSON config files on top of these values.
    """
    return {
        "do_mm_calib": False,
        "cpu_offload": False,
        # Attention parallelism backend: None, "ulysses" or "ring".
        "parallel_attn_type": None,
        "parallel_vae": False,
        "max_area": False,
        "vae_stride": (4, 8, 8),
        "patch_size": (1, 2, 2),
        # Feature caching strategy: "NoCaching", "TaylorSeer" or "Tea".
        "feature_caching": "NoCaching",
        "teacache_thresh": 0.26,
        "use_ret_steps": False,
        "use_bfloat16": True,
        "lora_path": None,
        "strength_model": 1.0,
    }
def set_config(args):
    """Assemble the runtime configuration from defaults, CLI args and JSON files.

    Precedence (lowest to highest): built-in defaults from
    get_default_config(), the parsed CLI arguments, the user-supplied
    --config_json file, and finally the model directory's own config.json
    (if present).

    Args:
        args: argparse.Namespace holding the parsed CLI arguments. Must
            provide at least mm_config, config_json and model_path.

    Returns:
        EasyDict: merged configuration with attribute-style access.
    """
    config = get_default_config()
    config.update(vars(args))
    config = EasyDict(config)

    # Optional per-matmul quantization settings passed as a JSON string.
    if args.mm_config:
        config.mm_config = json.loads(args.mm_config)
    else:
        config.mm_config = None

    with open(args.config_json, "r") as f:
        config.update(json.load(f))

    # Best-effort overlay of the model directory's own config.json;
    # failures are reported but never abort the run.
    try:
        model_config_path = os.path.join(args.model_path, "config.json")
        if os.path.exists(model_config_path):
            with open(model_config_path, "r") as f:
                config.update(json.load(f))
    except Exception as e:
        print(e)

    return config
......@@ -6,7 +6,7 @@ model_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
cuda_devices=6
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......@@ -26,16 +26,11 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
python -m lightx2v \
--model_cls hunyuan \
--model_path $model_path \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_i2v.json \
--prompt "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick." \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_1.jpg \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--save_video_path ./output_lightx2v_hy_i2v.mp4 \
--seed 0
--save_video_path ./output_lightx2v_hy_i2v.mp4
......@@ -6,7 +6,7 @@ model_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
cuda_devices=6
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......@@ -26,14 +26,10 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
python -m lightx2v \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_t2v.json \
--prompt "A cat walks on the grass, realistic style." \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn3 \
--save_video_path ./output_lightx2v_hy_t2v.mp4 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
--save_video_path ./output_lightx2v_hy_t2v.mp4
......@@ -27,26 +27,16 @@ export ENABLE_PROFILING_DEBUG=true
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist/hunyuan_t2v_dist_ulysses.json \
--prompt "A cat walks on the grass, realistic style." \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn_type ulysses \
--save_video_path ./output_lightx2v_hunyuan_t2v_dist_ulysses.mp4
torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/dist/hunyuan_t2v_dist_ring.json \
--prompt "A cat walks on the grass, realistic style." \
--infer_steps 20 \
--target_video_length 33 \
--target_height 720 \
--target_width 1280 \
--attention_type flash_attn2 \
--mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
--parallel_attn_type ring \
--save_video_path ./output_lightx2v_hunyuan_t2v_dist_ring.mp4
......@@ -25,16 +25,10 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
python ${lightx2v_path}/lightx2v/__main__.py \
python -m lightx2v \
--model_cls hunyuan \
--task t2v \
--model_path $model_path \
--prompt "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." \
--infer_steps 50 \
--target_video_length 65 \
--target_height 480 \
--target_width 640 \
--attention_type flash_attn2 \
--cpu_offload \
--feature_caching TaylorSeer \
--save_video_path ./output_lightx2v_hy_t2v.mp4 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}'
--config_json ${lightx2v_path}/configs/caching/hunyuan_t2v_TaylorSeer.json \
--prompt "A cat walks on the grass, realistic style." \
--save_video_path ./output_lightx2v_hy_t2v.mp4
......@@ -6,7 +6,7 @@ model_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
cuda_devices=6
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......@@ -26,22 +26,12 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
python -m lightx2v \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--infer_steps 40 \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn3 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ./output_lightx2v_wan_i2v.mp4 \
--sample_guide_scale 5 \
--sample_shift 5 \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--mm_config '{"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
--save_video_path ./output_lightx2v_wan_i2v.mp4
#!/bin/bash
# set path and first
lightx2v_path=''
model_path=''
lora_path=''
lightx2v_path=""
model_path=""
lora_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......@@ -37,7 +37,7 @@ python -m lightx2v \
--target_height 480 \
--attention_type flash_attn3 \
--seed 42 \
--sample_neg_promp "画面过曝,模糊,文字,字幕" \
--negative_prompt "画面过曝,模糊,文字,字幕" \
--save_video_path ./output_lightx2v_wan_i2v.mp4 \
--sample_guide_scale 5 \
--sample_shift 5 \
......
......@@ -6,7 +6,7 @@ model_path=""
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
cuda_devices=6
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......@@ -26,22 +26,11 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python ${lightx2v_path}/lightx2v/__main__.py \
python -m lightx2v \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--infer_steps 50 \
--target_video_length 81 \
--target_width 832 \
--target_height 480 \
--attention_type flash_attn2 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ./output_lightx2v_wan_t2v.mp4 \
--sample_guide_scale 6 \
--sample_shift 8 \
# --mm_config '{"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F", "weight_auto_quant": true}' \
# --feature_caching Tea \
# --use_ret_steps \
# --teacache_thresh 0.2
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ./output_lightx2v_wan_t2v.mp4
......@@ -36,7 +36,7 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--target_height 480 \
--attention_type flash_attn2 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--sample_guide_scale 6 \
--sample_shift 8 \
--parallel_attn_type ring \
......@@ -54,7 +54,7 @@ torchrun --nproc_per_node=4 ${lightx2v_path}/lightx2v/__main__.py \
--target_height 480 \
--attention_type flash_attn2 \
--seed 42 \
--sample_neg_promp 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--sample_guide_scale 6 \
--sample_shift 8 \
--parallel_attn_type ulysses \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment