"mmdet/models/dense_heads/maskformer_head.py" did not exist on "2f6baaee5db2641711a85f745e9e0a57a4049a1f"
infer.py 7.38 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import argparse
import os

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.common.ops import *
from lightx2v.models.runners.bagel.bagel_runner import BagelRunner  # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401
from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner  # noqa: F401
from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner  # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner  # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
from lightx2v.models.runners.wan.wan_vace_runner import Wan22MoeVaceRunner, WanVaceRunner  # noqa: F401
from lightx2v.models.runners.worldplay.worldplay_ar_runner import WorldPlayARRunner  # noqa: F401
from lightx2v.models.runners.worldplay.worldplay_bi_runner import WorldPlayBIRunner  # noqa: F401
from lightx2v.models.runners.worldplay.worldplay_distill_runner import WorldPlayDistillRunner  # noqa: F401
from lightx2v.models.runners.z_image.z_image_runner import ZImageRunner  # noqa: F401
from lightx2v.utils.envs import *
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
from lightx2v.utils.utils import seed_all, validate_config_paths, validate_task_arguments
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


def init_runner(config):
    """Build and initialise the runner selected by ``config["model_cls"]``.

    Gradient computation is disabled with ``torch.set_grad_enabled(False)``
    before the runner is constructed.

    Args:
        config: Configuration object supporting ``config["model_cls"]``; the
            value is used as the lookup key into ``RUNNER_REGISTER`` and the
            whole config is passed to the registered runner factory.

    Returns:
        The constructed runner, after ``init_modules()`` has been called on it.
    """
    torch.set_grad_enabled(False)
    runner_factory = RUNNER_REGISTER[config["model_cls"]]
    instance = runner_factory(config)
    instance.init_modules()
    return instance


def _build_arg_parser():
    """Create the argument parser describing every CLI option of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "wan2.1_distill",
            "wan2.1_mean_flow_distill",
            "wan2.1_vace",
            "wan2.1_sf",
            "wan2.1_sf_mtxg2",
            "seko_talk",
            "wan2.2_moe",
            "wan2.2",
            "wan2.2_moe_audio",
            "wan2.2_audio",
            "wan2.2_moe_distill",
            "wan2.2_moe_vace",
            "qwen_image",
            "longcat_image",
            "wan2.2_animate",
            "hunyuan_video_1.5",
            "hunyuan_video_1.5_distill",
            "worldplay_distill",
            "worldplay_ar",
            "worldplay_bi",
            "z_image",
            "ltx2",
            "bagel",
        ],
        # NOTE: the option is required, so argparse never falls back to this default.
        default="wan2.1",
    )

    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v", "rs2v", "t2av", "i2av"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--sf_model_path", type=str, required=False)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--use_prompt_enhancer", action="store_true")

    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")

    parser.add_argument(
        "--image_path",
        type=str,
        default="",
        help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'",
    )
    parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task")
    parser.add_argument("--image_strength", type=float, default=1.0, help="The strength of the image-to-audio-video (i2av) task")
    # [Warning] For vace task, need refactor.
    parser.add_argument(
        "--src_ref_images",
        type=str,
        default=None,
        help="The file list of the source reference images. Separated by ','. Default None.",
    )
    parser.add_argument(
        "--src_video",
        type=str,
        default=None,
        help="The file of the source video. Default None.",
    )
    parser.add_argument(
        "--src_mask",
        type=str,
        default=None,
        help="The file of the source mask. Default None.",
    )
    parser.add_argument(
        "--src_pose_path",
        type=str,
        default=None,
        help="The file of the source pose. Default None.",
    )
    parser.add_argument(
        "--src_face_path",
        type=str,
        default=None,
        help="The file of the source face. Default None.",
    )
    parser.add_argument(
        "--src_bg_path",
        type=str,
        default=None,
        help="The file of the source background. Default None.",
    )
    parser.add_argument(
        "--src_mask_path",
        type=str,
        default=None,
        help="The file of the source mask. Default None.",
    )
    parser.add_argument(
        "--pose",
        type=str,
        default=None,
        help="Pose string (e.g., 'w-3, right-0.5') or JSON file path for WorldPlay models.",
    )
    parser.add_argument(
        "--action_ckpt",
        type=str,
        default=None,
        help="Path to action model checkpoint for WorldPlay models.",
    )
    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
    parser.add_argument("--target_shape", nargs="+", default=[], help="Set return video or image shape")
    parser.add_argument("--aspect_ratio", type=str, default="")
    return parser


def main():
    """Command-line entry point: parse arguments, build a runner, run inference.

    Steps, in order:
      1. Parse and validate the CLI arguments, then seed the RNGs with ``--seed``.
      2. Build the config from the arguments and create an empty input-info
         object for the selected task.
      3. If ``config["parallel"]`` is truthy, initialise the parallel
         environment of the platform named by the ``PLATFORM`` environment
         variable (``"cuda"`` when unset) and apply the parallel config.
      4. Print the config, validate the paths it references, then construct
         the runner and run its pipeline inside the "Total Cost" profiling
         context.
      5. Destroy the ``torch.distributed`` process group if one was initialised.

    Raises:
        ValueError: If parallel mode is enabled and ``PLATFORM`` names a
            platform that has no entry in ``PLATFORM_DEVICE_REGISTER``.
    """
    args = _build_arg_parser().parse_args()
    validate_task_arguments(args)

    seed_all(args.seed)

    # set config
    config = set_config(args)
    # init input_info
    input_info = init_empty_input_info(args.task)

    if config["parallel"]:
        platform_name = os.getenv("PLATFORM", "cuda")
        platform_device = PLATFORM_DEVICE_REGISTER.get(platform_name, None)
        # Fail with an actionable message instead of an AttributeError on None
        # when the requested platform has not been registered.
        if platform_device is None:
            raise ValueError(f"Unsupported PLATFORM {platform_name!r}: no device is registered under that name in PLATFORM_DEVICE_REGISTER.")
        platform_device.init_parallel_env()
        set_parallel_config(config)

    print_config(config)

    validate_config_paths(config)

    with ProfilingContext4DebugL1("Total Cost"):
        # init runner
        runner = init_runner(config)
        # start to infer: copy the parsed CLI values into the input-info object
        update_input_info_from_dict(input_info, vars(args))
        runner.run_pipeline(input_info)

    # Clean up distributed process group
    if dist.is_initialized():
        dist.destroy_process_group()
        logger.info("Distributed process group cleaned up")


# Run the inference CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()