infer.py 3.52 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import argparse
2

PengGao's avatar
PengGao committed
3
4
import torch.distributed as dist
from loguru import logger
5

PengGao's avatar
PengGao committed
6
from lightx2v.common.ops import *
PengGao's avatar
PengGao committed
7
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
PengGao's avatar
PengGao committed
8
from lightx2v.models.runners.graph_runner import GraphRunner
PengGao's avatar
PengGao committed
9
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
10
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
PengGao's avatar
PengGao committed
11
12
13
14
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
PengGao's avatar
PengGao committed
15
16
17
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
helloyongyang's avatar
helloyongyang committed
18
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
PengGao's avatar
PengGao committed
19
from lightx2v.utils.utils import seed_all
lijiaqi2's avatar
lijiaqi2 committed
20
21


helloyongyang's avatar
helloyongyang committed
22
23
24
25
26
def init_runner(config):
    """Seed the RNGs, build the runner registered for ``config.model_cls`` and initialise it.

    Args:
        config: Configuration object; this function reads ``config.seed`` and
            ``config.model_cls`` and hands the whole object to the runner
            constructor.

    Returns:
        The initialised runner. When ``CHECK_ENABLE_GRAPH_MODE()`` is truthy the
        runner is wrapped in a ``GraphRunner`` before being returned.
    """
    seed_all(config.seed)

    # Query the graph-mode switch before constructing anything so the order of
    # calls matches the two-branch form this replaces.
    use_graph_mode = CHECK_ENABLE_GRAPH_MODE()

    runner_cls = RUNNER_REGISTER[config.model_cls]
    base_runner = runner_cls(config)
    base_runner.init_modules()

    # Graph mode wraps the already-initialised runner; otherwise use it directly.
    return GraphRunner(base_runner) if use_graph_mode else base_runner


35
def main():
    """Command-line entry point: parse arguments, build the config and run one pipeline.

    Steps, in order:
      1. Parse the CLI arguments declared below.
      2. Turn them into a config object via ``set_config``.
      3. If ``config.parallel`` is truthy, initialise an NCCL process group,
         select the CUDA device and apply ``set_parallel_config``.
      4. Print the config, then build the runner and call ``run_pipeline()``
         inside a ``ProfilingContext("Total Cost")``.
      5. Destroy the process group if one was initialised, even when step 3 or 4
         raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "hunyuan",
            "wan2.1_distill",
            "wan2.1_causvid",
            "wan2.1_skyreels_v2_df",
            "cogvideox",
            "wan2.1_audio",
            "wan2.2_moe",
            "wan2.2",
            "wan2.2_moe_audio",
            "wan2.2_audio",
            "wan2.2_moe_distill",
        ],
        default="wan2.1",
    )

    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--use_prompt_enhancer", action="store_true")

    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")

    parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")

    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
    args = parser.parse_args()

    # set config
    config = set_config(args)

    try:
        if config.parallel:
            # ``import torch.distributed as dist`` binds only ``dist``; import
            # ``torch`` here so the name does not depend on a star import
            # happening to re-export it.
            import torch

            dist.init_process_group(backend="nccl")
            # NOTE(review): the global rank is used as the CUDA device index,
            # which presumably assumes a single-node launch - confirm before
            # running multi-node.
            torch.cuda.set_device(dist.get_rank())
            set_parallel_config(config)

        print_config(config)

        with ProfilingContext("Total Cost"):
            runner = init_runner(config)
            runner.run_pipeline()
    finally:
        # Clean up distributed process group, also when the pipeline failed,
        # so an error does not skip the teardown.
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group cleaned up")

Xinchi Huang's avatar
Xinchi Huang committed
104
105

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()