Unverified Commit 4d9f8201 authored by Zhuguanyu Wu's avatar Zhuguanyu Wu Committed by GitHub

Dev hy15 distill (#506)

Support the hunyuan_video_1.5_distill model class.
parent 0b23aca9
{
"infer_steps": 4,
"transformer_model_name": "480p_t2v",
"fps": 16,
"target_video_length": 81,
"aspect_ratio": "16:9",
"vae_stride": [4, 16, 16],
"sample_shift": 5.0,
"sample_guide_scale": -1.0,
"enable_cfg": false,
"attn_type": "sage_attn2",
"dit_original_ckpt": "hunyuanvideo-1.5/distill_models/480p_t2v/distill_model.safetensors",
"denoising_step_list": [
1000,
750,
500,
250
]
}
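Note on the config: denoising_step_list is the field that drives the distilled sampler; the scheduler introduced below maps each listed training timestep onto a shift-warped sigma schedule instead of spacing infer_steps evenly. A minimal standalone sketch of that mapping (same math as set_timesteps later in this diff):

import torch

# Sketch: how denoising_step_list becomes the 4 sampled timesteps (shift = 5.0).
shift, num_train_timesteps = 5.0, 1000
denoising_step_list = [1000, 750, 500, 250]

sigmas = torch.linspace(1.0, 0.0, num_train_timesteps + 1)[:-1]  # 1.000, 0.999, ..., 0.001
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)             # flow-matching shift warp
index = [num_train_timesteps - t for t in denoising_step_list]   # [0, 250, 500, 750]
print(sigmas[index] * num_train_timesteps)  # tensor([1000.0000, 937.5000, 833.3333, 625.0000])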
......@@ -5,6 +5,7 @@ import torch.distributed as dist
from loguru import logger
from lightx2v.common.ops import *
+ from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401
......@@ -51,6 +52,7 @@ def main():
"qwen_image",
"wan2.2_animate",
"hunyuan_video_1.5",
"hunyuan_video_1.5_distill",
],
default="wan2.1",
)
......
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner
from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15SRScheduler
from lightx2v.models.schedulers.hunyuan_video.step_distill.scheduler import HunyuanVideo15StepDistillScheduler
from lightx2v.utils.registry_factory import RUNNER_REGISTER


@RUNNER_REGISTER("hunyuan_video_1.5_distill")
class HunyuanVideo15DistillRunner(HunyuanVideo15Runner):
    """Runner for the step-distilled HunyuanVideo 1.5 models; only the scheduler differs from the base runner."""

    def __init__(self, config):
        super().__init__(config)

    def init_scheduler(self):
        # Use the step-distill scheduler for denoising; the optional
        # super-resolution scheduler is kept identical to the base runner.
        self.scheduler = HunyuanVideo15StepDistillScheduler(self.config)
        if self.sr_version is not None:
            self.scheduler_sr = HunyuanVideo15SRScheduler(self.config_sr)
        else:
            self.scheduler_sr = None
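For orientation: RUNNER_REGISTER appears to act as a name-to-class registry keyed by the --model_cls string, which is how "hunyuan_video_1.5_distill" on the CLI reaches this class. A minimal sketch of that pattern, with hypothetical names (_RUNNERS and get_runner are illustrative, not LightX2V's actual internals):

# Illustrative registry pattern (hypothetical names, not LightX2V's real code).
_RUNNERS: dict[str, type] = {}

def RUNNER_REGISTER(name: str):
    def wrap(cls):
        _RUNNERS[name] = cls  # e.g. "hunyuan_video_1.5_distill" -> HunyuanVideo15DistillRunner
        return cls
    return wrap

def get_runner(name: str, config):
    return _RUNNERS[name](config)  # instantiated from the --model_cls argument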
import torch

from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15Scheduler


class HunyuanVideo15StepDistillScheduler(HunyuanVideo15Scheduler):
    def __init__(self, config):
        super().__init__(config)
        # The distilled model is only evaluated at the timesteps listed in the
        # config, so the number of inference steps is fixed by that list.
        self.denoising_step_list = config["denoising_step_list"]
        self.infer_steps = len(self.denoising_step_list)
        self.num_train_timesteps = 1000
        self.sigma_max = 1.0
        self.sigma_min = 0.0

    def set_timesteps(self, num_inference_steps, device, shift):
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)  # == sigma_max
        # Dense training-time sigma grid (1.000, 0.999, ..., 0.001), shift-warped
        # as in flow matching, then subsampled at the distilled steps only.
        self.sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_train_timesteps + 1)[:-1]
        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)
        self.timesteps = self.sigmas * self.num_train_timesteps
        self.denoising_step_index = [self.num_train_timesteps - x for x in self.denoising_step_list]
        self.timesteps = self.timesteps[self.denoising_step_index].to(device)
        self.sigmas = self.sigmas[self.denoising_step_index].to("cpu")

    def step_post(self):
        # Form the x0 estimate from the flow prediction, then re-noise it
        # deterministically along the same flow to the next sigma.
        flow_pred = self.noise_pred.to(torch.float32)
        sigma = self.sigmas[self.step_index].item()
        noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
        if self.step_index < self.infer_steps - 1:
            sigma_n = self.sigmas[self.step_index + 1].item()
            noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
        self.latents = noisy_image_or_video.to(self.latents.dtype)
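Concretely, with the 480p distill config above the selected sigmas come out to [1.0, 0.9375, 0.8333, 0.625]; each step_post forms the x0 estimate latents - sigma * flow_pred and, except on the last step, re-noises it along the same predicted flow to the next sigma. A self-contained numeric sketch (random tensors stand in for real latents and the DiT output):

import torch

sigmas = [1.0, 0.9375, 0.8333, 0.625]              # shifted sigmas for steps [1000, 750, 500, 250]
latents = torch.randn(1, 32, 3, 30, 52)            # stand-in latent tensor
for i, sigma in enumerate(sigmas):
    flow_pred = torch.randn_like(latents)          # stand-in for the DiT's flow prediction
    x0 = latents - sigma * flow_pred               # predicted clean latents
    if i < len(sigmas) - 1:
        latents = x0 + sigmas[i + 1] * flow_pred   # deterministic re-noise to next sigma
    else:
        latents = x0                               # final step returns the x0 estimate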
......@@ -43,9 +43,8 @@ class WanStepDistillScheduler(WanScheduler):
sigma = self.sigmas[self.step_index].item()
noisy_image_or_video = self.latents.to(torch.float32) - sigma * flow_pred
if self.step_index < self.infer_steps - 1:
- sigma = self.sigmas[self.step_index + 1].item()
- noise = torch.randn(noisy_image_or_video.shape, dtype=torch.float32, device=self.device, generator=self.generator)
- noisy_image_or_video = self.add_noise(noisy_image_or_video, noise=noise, sigma=self.sigmas[self.step_index + 1].item())
+ sigma_n = self.sigmas[self.step_index + 1].item()
+ noisy_image_or_video = noisy_image_or_video + flow_pred * sigma_n
self.latents = noisy_image_or_video.to(self.latents.dtype)
......
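The Wan hunk above brings WanStepDistillScheduler in line with the Hunyuan variant: instead of drawing fresh Gaussian noise and re-noising through add_noise at each intermediate step, the update re-injects the model's own flow prediction, so the few-step trajectory is deterministic given the model outputs. Schematically (a sketch, not the repo's code; the old branch assumes add_noise follows the usual flow-matching form (1 - sigma) * x + sigma * noise):

import torch

x0, flow_pred, sigma_n = torch.zeros(4), torch.ones(4), 0.9375

# Old behaviour: stochastic re-noising with fresh Gaussian noise (add_noise form assumed).
noise = torch.randn_like(x0)
x_next_old = (1 - sigma_n) * x0 + sigma_n * noise  # different on every draw

# New behaviour: deterministic re-noising along the predicted flow.
x_next_new = x0 + sigma_n * flow_pred              # fixed given the model output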
......@@ -94,7 +94,7 @@ class LightX2VPipeline:
elif self.model_cls in ["wan2.2"]:
self.vae_stride = (4, 16, 16)
self.num_channels_latents = 48
- elif self.model_cls in ["hunyuan_video_1.5"]:
+ elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
self.vae_stride = (4, 16, 16)
self.num_channels_latents = 32
......@@ -174,7 +174,7 @@ class LightX2VPipeline:
self.self_attn_1_type = attn_mode
self.cross_attn_1_type = attn_mode
self.cross_attn_2_type = attn_mode
- elif self.model_cls in ["hunyuan_video_1.5"]:
+ elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
self.attn_type = attn_mode
def set_infer_config_json(self, config_json):
......@@ -222,7 +222,7 @@ class LightX2VPipeline:
self.clip_quant_scheme = quant_scheme
self.clip_quantized = image_encoder_quantized
self.clip_quantized_ckpt = image_encoder_quantized_ckpt
- elif self.model_cls in ["hunyuan_video_1.5"]:
+ elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
self.qwen25vl_quantized = text_encoder_quantized
self.qwen25vl_quantized_ckpt = text_encoder_quantized_ckpt
self.qwen25vl_quant_scheme = quant_scheme
......@@ -255,7 +255,7 @@ class LightX2VPipeline:
self.t5_cpu_offload = text_encoder_offload
self.clip_encoder_offload = image_encoder_offload
- elif self.model_cls in ["hunyuan_video_1.5"]:
+ elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
self.qwen25vl_cpu_offload = text_encoder_offload
self.siglip_cpu_offload = image_encoder_offload
self.byt5_cpu_offload = image_encoder_offload
......
......@@ -43,7 +43,7 @@ def set_config(args):
config_json = json.load(f)
config.update(config_json)
if config["model_cls"] == "hunyuan_video_1.5": # Special config for hunyuan video 1.5 model folder structure
if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: # Special config for hunyuan video 1.5 model folder structure
config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"]) # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v]
if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")):
with open(os.path.join(config["transformer_model_path"], "config.json"), "r") as f:
......@@ -79,7 +79,7 @@ def set_config(args):
logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1
if config["task"] not in ["t2i", "i2i"] and config["model_cls"] != "hunyuan_video_1.5":
if config["task"] not in ["t2i", "i2i"] and config["model_cls"] not in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
if config["model_cls"] == "seko_talk":
config["attnmap_frame_num"] += 1
......
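A numeric check of the two formulas above: with vae_stride[0] = 4 the default target_video_length of 81 already satisfies (num_frames - 1) % 4 == 0, while e.g. 80 would round to 81; the attention-map frame count then follows (this branch is skipped for the hunyuan_video_1.5 family, and patch_size[0] = 1 is assumed here purely for illustration):

# Worked example for the rounding and attnmap_frame_num formulas (patch_t assumed).
vae_stride_t, patch_t = 4, 1
target_video_length = 81
if (target_video_length - 1) % vae_stride_t != 0:
    target_video_length = target_video_length // vae_stride_t * vae_stride_t + 1
attnmap_frame_num = ((target_video_length - 1) // vae_stride_t + 1) // patch_t
print(target_video_length, attnmap_frame_num)  # -> 81 21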
......@@ -4,7 +4,7 @@
lightx2v_path=
model_path=
- export CUDA_VISIBLE_DEVICES=2
+ export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
......
#!/bin/bash
# set path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--seed 123 \
--model_cls hunyuan_video_1.5_distill \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json \
--prompt "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." \
--negative_prompt "" \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_hunyuan_video_15_t2v_distill.mp4
import argparse
from pathlib import Path
from loguru import logger
from post_multi_servers import get_available_urls, process_tasks_async
def load_prompts_from_folder(folder_path):
"""Load prompts from all files in the specified folder.
Returns:
tuple: (prompts, filenames) where prompts is a list of prompt strings
and filenames is a list of corresponding filenames
"""
prompts = []
filenames = []
folder = Path(folder_path)
if not folder.exists() or not folder.is_dir():
logger.error(f"Prompt folder does not exist or is not a directory: {folder_path}")
return prompts, filenames
# Get all files in the folder and sort them
files = sorted(folder.glob("*"))
files = [f for f in files if f.is_file()]
for file_path in files:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read().strip()
if content: # Only add non-empty prompts
prompts.append(content)
filenames.append(file_path.name)
# logger.info(f"Loaded prompt from {file_path.name}")
except Exception as e:
logger.warning(f"Failed to read file {file_path}: {e}")
return prompts, filenames
def load_prompts_from_file(file_path):
"""Load prompts from a file, one prompt per line.
Returns:
list: prompts, where each element is a prompt string
"""
prompts = []
file = Path(file_path)
if not file.exists() or not file.is_file():
logger.error(f"Prompt file does not exist or is not a file: {file_path}")
return prompts
try:
with open(file, "r", encoding="utf-8") as f:
for line in f:
prompt = line.strip()
if prompt: # Only add non-empty prompts
prompts.append(prompt)
except Exception as e:
logger.error(f"Failed to read prompt file {file_path}: {e}")
return prompts
if __name__ == "__main__":
urls = ["http://localhost:8000", "http://localhost:8001"]
prompts = [
"A cat walks on the grass, realistic style.",
"A person is riding a bike. Realistic, Natural lighting, Casual.",
"A car turns a corner. Realistic, Natural lighting, Casual.",
"An astronaut is flying in space, Van Gogh style. Dark, Mysterious.",
"A beautiful coastal beach in spring, waves gently lapping on the sand, the camera movement is Zoom In. Realistic, Natural lighting, Peaceful.",
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
]
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
parser = argparse.ArgumentParser(description="Post prompts to multiple T2V servers")
parser.add_argument("--prompt-folder", type=str, default=None, help="Folder containing prompt files. If not specified, use default prompts.")
parser.add_argument("--prompt-file", type=str, default=None, help="File containing prompts, one prompt per line. Cannot be used together with --prompt-folder.")
parser.add_argument("--save-folder", type=str, default="./", help="Folder to save output videos. Default is current directory.")
args = parser.parse_args()
messages = []
for i, prompt in enumerate(prompts):
messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": "", "save_result_path": f"./output_lightx2v_wan_t2v_{i + 1}.mp4"})
# Check that --prompt-folder and --prompt-file are not used together
if args.prompt_folder and args.prompt_file:
logger.error("Cannot use --prompt-folder and --prompt-file together. Please choose one.")
exit(1)
# Generate URLs from IPs (each IP has 8 ports: 8000-8007)
ips = ["localhost"]
urls = [f"http://{ip}:{port}" for ip in ips for port in range(8000, 8008)]
# urls = ["http://localhost:8007"]
logger.info(f"urls: {urls}")
......@@ -24,6 +87,75 @@ if __name__ == "__main__":
if not available_urls:
exit(1)
logger.info(f"Total {len(available_urls)} available servers.")
# Load prompts from folder, file, or use default prompts
prompt_filenames = None
if args.prompt_folder:
logger.info(f"Loading prompts from folder: {args.prompt_folder}")
prompts, prompt_filenames = load_prompts_from_folder(args.prompt_folder)
if not prompts:
logger.error("No valid prompts loaded from folder.")
exit(1)
elif args.prompt_file:
logger.info(f"Loading prompts from file: {args.prompt_file}")
prompts = load_prompts_from_file(args.prompt_file)
if not prompts:
logger.error("No valid prompts loaded from file.")
exit(1)
else:
logger.info("Using default prompts")
prompts = [
"A cat walks on the grass, realistic style.",
"A person is riding a bike. Realistic, Natural lighting, Casual.",
"A car turns a corner. Realistic, Natural lighting, Casual.",
"An astronaut is flying in space, Van Gogh style. Dark, Mysterious.",
"A beautiful coastal beach in spring, waves gently lapping on the sand, the camera movement is Zoom In. Realistic, Natural lighting, Peaceful.",
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
]
negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
# Prepare save folder
save_folder = Path(args.save_folder)
save_folder.mkdir(parents=True, exist_ok=True)
messages = []
total_count = len(prompts)
skipped_count = 0
for i, prompt in enumerate(prompts):
# Generate output filename
if prompt_filenames:
# Use prompt filename, replace extension with .mp4
filename = Path(prompt_filenames[i]).stem + ".mp4"
else:
# Use default naming
filename = f"output_lightx2v_wan_t2v_{i + 1}.mp4"
save_path = save_folder / filename
# Skip if file already exists (only when using prompt_filenames)
if prompt_filenames and save_path.exists():
logger.info(f"Skipping {filename} - file already exists")
skipped_count += 1
continue
messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": "", "save_result_path": str(save_path)})
# Log statistics
to_process_count = len(messages)
logger.info("=" * 80)
logger.info("Task Statistics:")
logger.info(f" Total prompts: {total_count}")
logger.info(f" Skipped (already exists): {skipped_count}")
logger.info(f" To process: {to_process_count}")
logger.info("=" * 80)
if to_process_count == 0:
logger.info("No tasks to process. All files already exist.")
exit(0)
# Process tasks asynchronously
success = process_tasks_async(messages, available_urls, show_progress=True)
......
#!/bin/bash
# set path first
- lightx2v_path=/path/to/Lightx2v
- model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8
+ lightx2v_path=
+ model_path=
export CUDA_VISIBLE_DEVICES=0
......@@ -12,10 +12,10 @@ source ${lightx2v_path}/scripts/base/base.sh
# Start API server with distributed inference service
python -m lightx2v.server \
- --model_cls seko_talk \
- --task s2v \
+ --model_cls hunyuan_video_1.5_distill \
+ --task t2v \
--model_path $model_path \
- --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
+ --config_json ${lightx2v_path}/configs/hunyuan_video_15/hunyuan_video_t2v_480p_distill.json \
--port 8000
echo "Service stopped"