infer.py 3.82 KB
Newer Older
Xinchi Huang's avatar
Xinchi Huang committed
1
import asyncio
helloyongyang's avatar
helloyongyang committed
2
3
4
5
import argparse
import torch
import torch.distributed as dist
import json
6

7
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
8
9
from lightx2v.utils.utils import seed_all
from lightx2v.utils.profiler import ProfilingContext
10
from lightx2v.utils.set_config import set_config
helloyongyang's avatar
helloyongyang committed
11
from lightx2v.utils.registry_factory import RUNNER_REGISTER
12

helloyongyang's avatar
helloyongyang committed
13
14
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
15
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
16
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
wangshankun's avatar
wangshankun committed
17
from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
18
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
19
from lightx2v.models.runners.graph_runner import GraphRunner
Watebear's avatar
Watebear committed
20
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
21

22
from lightx2v.common.ops import *
23
from loguru import logger
lijiaqi2's avatar
lijiaqi2 committed
24
25


helloyongyang's avatar
helloyongyang committed
26
27
28
29
def init_runner(config):
    """Build and initialise the inference runner described by ``config``.

    Steps, in order:
      1. Seed RNGs via ``seed_all(config.seed)``.
      2. If ``config.parallel_attn_type`` is truthy and no process group exists
         yet, initialise a ``torch.distributed`` process group with the NCCL
         backend.
      3. Instantiate the runner class registered under ``config.model_cls`` in
         ``RUNNER_REGISTER``.
      4. In graph mode (``CHECK_ENABLE_GRAPH_MODE()`` is true), wrap it in a
         ``GraphRunner`` and initialise the *inner* runner's modules;
         otherwise initialise the runner's modules directly.

    Args:
        config: Config object supporting attribute access (``seed``,
            ``parallel_attn_type``, ``model_cls``). ``main`` builds it with
            ``set_config(args)``.

    Returns:
        The runner with modules initialised. In graph mode this is the
        ``GraphRunner`` wrapper, and the concrete runner is ``runner.runner``.
    """
    seed_all(config.seed)

    # Distributed attention needs a process group. The is_initialized() guard
    # avoids re-initialising one that already exists.
    # NOTE(review): no per-rank device selection (e.g. torch.cuda.set_device)
    # happens here; confirm the runner handles it.
    if config.parallel_attn_type and not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    runner_cls = RUNNER_REGISTER[config.model_cls]
    if CHECK_ENABLE_GRAPH_MODE():
        # The wrapper is constructed before module initialisation, matching
        # the original order; init_modules() is called on the wrapped runner.
        runner = GraphRunner(runner_cls(config))
        runner.runner.init_modules()
    else:
        runner = runner_cls(config)
        runner.init_modules()

    return runner


Xinchi Huang's avatar
Xinchi Huang committed
43
def _build_arg_parser():
    """Return the ``argparse`` parser for the inference command line."""
    parser = argparse.ArgumentParser()
    # `required=True` makes any `default` unreachable, so none is declared.
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"],
    )
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--use_prompt_enhancer", action="store_true")

    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
    parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")

    parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
    parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
    return parser


def _load_prompt_from_file(args):
    """Overwrite ``args.prompt`` with the stripped UTF-8 contents of ``args.prompt_path``.

    Does nothing when ``args.prompt_path`` is empty. Any read error is
    logged and re-raised.
    """
    if not args.prompt_path:
        return
    try:
        with open(args.prompt_path, "r", encoding="utf-8") as f:
            args.prompt = f.read().strip()
    except FileNotFoundError:
        # Message: "prompt file not found: <path>"
        logger.error(f"找不到prompt文件: {args.prompt_path}")
        raise
    except Exception as e:
        # Message: "error while reading prompt file: <error>"
        logger.error(f"读取prompt文件时出错: {e}")
        raise
    # Message: "read prompt from file <path>: <prompt>"
    logger.info(f"从文件 {args.prompt_path} 读取到prompt: {args.prompt}")


def _apply_lora_args(args):
    """Fold ``--lora_path``/``--lora_strength`` into a single ``args.lora_configs`` list.

    This only happens when a LoRA path was given. The flat attributes are then
    removed, so the later ``set_config(args)`` call sees only ``lora_configs``.
    When no path is given, ``args`` is left untouched.
    """
    if args.lora_path:
        args.lora_configs = [{"path": args.lora_path, "strength": args.lora_strength}]
        delattr(args, "lora_path")
        delattr(args, "lora_strength")


async def main():
    """Command-line entry point: parse arguments, build config and runner, run the pipeline.

    Configuration, runner initialisation and the pipeline run all execute
    inside ``ProfilingContext("Total Cost")``.
    """
    args = _build_arg_parser().parse_args()
    _load_prompt_from_file(args)
    _apply_lora_args(args)
    logger.info(f"args: {args}")

    with ProfilingContext("Total Cost"):
        config = set_config(args)
        config["mode"] = "infer"
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = init_runner(config)
        await runner.run_pipeline()


if __name__ == "__main__":
    # Create the top-level coroutine and drive it to completion on a fresh
    # event loop.
    _entry_coroutine = main()
    asyncio.run(_entry_coroutine)