import argparse
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
import json

from lightx2v.utils.envs import *
from lightx2v.utils.utils import seed_all
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.registry_factory import RUNNER_REGISTER

from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner, Wan22MoeRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner, Wan22MoeAudioRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner

from lightx2v.common.ops import *
from loguru import logger


def init_runner(config):
    """Build and initialise the inference runner described by ``config``.

    Seeds the RNGs via ``seed_all(config.seed)``. When ``config.parallel`` is
    truthy, joins (or reuses) an NCCL process group and stores a 2-D CUDA
    device mesh with dimensions ``("cfg_p", "seq_p")`` in
    ``config["device_mesh"]``. The runner class is looked up in
    ``RUNNER_REGISTER`` by ``config.model_cls``; in graph mode it is wrapped
    in ``GraphRunner``.

    Args:
        config: Configuration returned by ``set_config``. It must support
            attribute access (``config.seed``, ``config.parallel``,
            ``config.model_cls``) and item assignment
            (``config["device_mesh"]``).

    Returns:
        The runner (possibly a ``GraphRunner`` wrapper) after
        ``init_modules()`` has been called on the underlying runner.

    Raises:
        ValueError: If ``cfg_p_size * seq_p_size`` does not equal the
            distributed world size.
    """
    seed_all(config.seed)

    if config.parallel:
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        cfg_p_size = config.parallel.get("cfg_p_size", 1)
        seq_p_size = config.parallel.get("seq_p_size", 1)
        world_size = dist.get_world_size()
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let a mismatched mesh through silently.
        if cfg_p_size * seq_p_size != world_size:
            raise ValueError(
                "cfg_p_size * seq_p_size must be equal to world_size, "
                f"got cfg_p_size={cfg_p_size}, seq_p_size={seq_p_size}, world_size={world_size}"
            )
        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))

    runner_cls = RUNNER_REGISTER[config.model_cls]
    if CHECK_ENABLE_GRAPH_MODE():
        # The wrapper is built around the not-yet-initialised runner (same
        # order as before); modules are then initialised on the wrapped runner.
        runner = GraphRunner(runner_cls(config))
        runner.runner.init_modules()
    else:
        runner = runner_cls(config)
        runner.init_modules()

    return runner


def main():
    """Command-line entry point: parse arguments, build the config, run inference.

    The config build, runner initialisation and pipeline run are timed under
    ``ProfilingContext("Total Cost")``. If a ``torch.distributed`` process
    group is initialised, it is destroyed on exit, including when any of
    those steps raises.
    """
    parser = argparse.ArgumentParser()
    # `required=True` means argparse never uses a default, so none is given.
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "hunyuan",
            "wan2.1_distill",
            "wan2.1_causvid",
            "wan2.1_skyreels_v2_df",
            "cogvideox",
            "wan2.1_audio",
            "wan2.2_moe",
            "wan2.2_moe_audio",
            "wan2.2",
        ],
    )

    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--use_prompt_enhancer", action="store_true")

    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")

    parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")

    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
    args = parser.parse_args()

    logger.info(f"args: {args}")

    try:
        with ProfilingContext("Total Cost"):
            config = set_config(args)
            logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
            runner = init_runner(config)

            runner.run_pipeline()
    finally:
        # Clean up distributed process group. This runs in `finally` so the
        # group is also released when setup or the pipeline raises.
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group cleaned up")


if __name__ == "__main__":
    main()