infer.py 5.09 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import argparse
2

Yang Yong(雍洋)'s avatar
Yang Yong(雍洋) committed
3
import torch
PengGao's avatar
PengGao committed
4
5
import torch.distributed as dist
from loguru import logger
Gu Shiqiao's avatar
Gu Shiqiao committed
6
7
8
9
10

try:
    from torch.distributed import ProcessGroupNCCL
except ImportError:
    ProcessGroupNCCL = None
11

PengGao's avatar
PengGao committed
12
from lightx2v.common.ops import *
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
13
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
Yang Yong (雍洋)'s avatar
Yang Yong (雍洋) committed
14
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401
15
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
16
from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
sandy's avatar
sandy committed
17
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
PengGao's avatar
PengGao committed
18
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
19
from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner  # noqa: F401
PengGao's avatar
PengGao committed
20
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
21
from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
gushiqiao's avatar
gushiqiao committed
22
from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
PengGao's avatar
PengGao committed
23
from lightx2v.utils.envs import *
24
from lightx2v.utils.input_info import set_input_info
25
from lightx2v.utils.profiler import *
PengGao's avatar
PengGao committed
26
from lightx2v.utils.registry_factory import RUNNER_REGISTER
helloyongyang's avatar
helloyongyang committed
27
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
PengGao's avatar
PengGao committed
28
from lightx2v.utils.utils import seed_all
lijiaqi2's avatar
lijiaqi2 committed
29
30


helloyongyang's avatar
helloyongyang committed
31
def init_runner(config):
    """Create the runner registered for ``config["model_cls"]`` and load its modules.

    Gradient tracking is switched off globally before the runner is built.

    Args:
        config: Configuration mapping; ``config["model_cls"]`` selects the entry
            in ``RUNNER_REGISTER`` and the whole mapping is passed to the
            runner's constructor.

    Returns:
        The constructed runner, after ``init_modules()`` has been called on it.
    """
    torch.set_grad_enabled(False)

    runner_cls = RUNNER_REGISTER[config["model_cls"]]
    new_runner = runner_cls(config)
    new_runner.init_modules()
    return new_runner


38
def _build_arg_parser():
    """Return the argument parser for the inference CLI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "wan2.1_distill",
            "wan2.1_vace",
            "wan2.1_sf",
            "wan2.1_sf_mtxg2",
            "seko_talk",
            "wan2.2_moe",
            "wan2.2",
            "wan2.2_moe_audio",
            "wan2.2_audio",
            "wan2.2_moe_distill",
            "qwen_image",
            "wan2.2_animate",
            "hunyuan_video_1.5",
            "hunyuan_video_1.5_distill",
        ],
        # NOTE: unused in practice because the argument is required=True.
        default="wan2.1",
    )

    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--sf_model_path", type=str, required=False)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--use_prompt_enhancer", action="store_true")

    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")

    parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
    parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task")

    # [Warning] For vace task, need refactor.
    parser.add_argument(
        "--src_ref_images",
        type=str,
        default=None,
        help="The file list of the source reference images. Separated by ','. Default None.",
    )
    parser.add_argument(
        "--src_video",
        type=str,
        default=None,
        help="The file of the source video. Default None.",
    )
    parser.add_argument(
        "--src_mask",
        type=str,
        default=None,
        help="The file of the source mask. Default None.",
    )

    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
    return parser


def _local_rank():
    """Return the per-node device index for this process.

    Prefers the ``LOCAL_RANK`` environment variable (set by launchers such as
    torchrun). Falls back to the global rank, which matches the previous
    behavior on a single node.
    """
    import os

    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        return int(local_rank)
    # Global rank is only a valid device index when all ranks share one node.
    return dist.get_rank()


def _init_distributed(config):
    """Initialize torch.distributed and bind this process to its device.

    ``config.get("run_device", "cuda")`` selects the backend: NCCL when it
    contains "cuda", CNCL when it contains "mlu". Any other value leaves the
    process group uninitialized and logs a warning.
    """
    run_device = config.get("run_device", "cuda")
    if "cuda" in run_device:
        if ProcessGroupNCCL is not None:
            pg_options = ProcessGroupNCCL.Options()
            pg_options.is_high_priority_stream = True
            dist.init_process_group(backend="nccl", pg_options=pg_options)
        else:
            # The guarded import at the top of the file yields None on torch
            # builds without ProcessGroupNCCL; use default NCCL options there
            # instead of crashing on None.Options().
            dist.init_process_group(backend="nccl")
        torch.cuda.set_device(_local_rank())
    elif "mlu" in run_device:
        dist.init_process_group(backend="cncl")
        torch.mlu.set_device(_local_rank())
    else:
        logger.warning(f"parallel is enabled but run_device={run_device!r} is not handled; distributed process group not initialized")


def main():
    """CLI entry point: parse arguments, build the config, and run one inference pipeline.

    When ``config["parallel"]`` is truthy, the distributed process group is
    initialized first. It is destroyed on exit, including when the pipeline
    raises.
    """
    args = _build_arg_parser().parse_args()

    seed_all(args.seed)

    # set config
    config = set_config(args)

    try:
        if config["parallel"]:
            _init_distributed(config)
            set_parallel_config(config)

        print_config(config)

        with ProfilingContext4DebugL1("Total Cost"):
            runner = init_runner(config)
            input_info = set_input_info(args)
            runner.run_pipeline(input_info)
    finally:
        # Clean up distributed process group (also on failure, so the group is not leaked).
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group cleaned up")

Xinchi Huang's avatar
Xinchi Huang committed
131
132

if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()