infer.py 2.13 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
import argparse
import torch
import torch.distributed as dist
import json
5

6
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
7
8
from lightx2v.utils.utils import seed_all
from lightx2v.utils.profiler import ProfilingContext
9
from lightx2v.utils.set_config import set_config
helloyongyang's avatar
helloyongyang committed
10
from lightx2v.utils.registry_factory import RUNNER_REGISTER
11

helloyongyang's avatar
helloyongyang committed
12
13
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
14
from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
15
16
from lightx2v.models.runners.graph_runner import GraphRunner

17
from lightx2v.common.ops import *
root's avatar
root committed
18
from loguru import logger
lijiaqi2's avatar
lijiaqi2 committed
19
20


helloyongyang's avatar
helloyongyang committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def init_runner(config):
    """Build the runner object for the given configuration.

    Steps, in order:
      1. Call ``seed_all`` with ``config.seed``.
      2. If ``config.parallel_attn_type`` is truthy, initialise the
         ``torch.distributed`` process group with the NCCL backend.
      3. Instantiate the class registered under ``config.model_cls`` in
         ``RUNNER_REGISTER``.
      4. If ``CHECK_ENABLE_GRAPH_MODE()`` returned a truthy value, wrap that
         instance in a ``GraphRunner``.

    Args:
        config: Configuration object exposing ``seed``,
            ``parallel_attn_type`` and ``model_cls`` attributes.

    Returns:
        The registered runner instance, or a ``GraphRunner`` wrapping it
        when graph mode is enabled.
    """
    seed_all(config.seed)

    if config.parallel_attn_type:
        dist.init_process_group(backend="nccl")

    # Query graph mode before constructing the runner so the evaluation
    # order matches the two-branch form this replaces.
    graph_mode_enabled = CHECK_ENABLE_GRAPH_MODE()
    runner_cls = RUNNER_REGISTER[config.model_cls]
    base_runner = runner_cls(config)

    return GraphRunner(base_runner) if graph_mode_enabled else base_runner


helloyongyang's avatar
helloyongyang committed
35
36
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw
    string, so every non-empty value -- including ``"False"`` and ``"0"`` --
    becomes ``True``. This converter interprets the text instead.

    Args:
        value: The raw argument string (or an actual ``bool``, which is
            returned unchanged).

    Returns:
        ``True`` or ``False`` according to the spelled-out value. The empty
        string maps to ``False``, matching what ``bool("")`` produced.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Create the command-line parser for the inference script.

    Returns:
        argparse.ArgumentParser: Parser accepting ``--model_cls``, ``--task``,
        ``--model_path``, ``--config_json``, ``--enable_cfg``, ``--prompt``,
        ``--negative_prompt``, ``--image_path`` and ``--save_video_path``.
    """
    parser = argparse.ArgumentParser()
    # --model_cls is required, so a default would never be consulted.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    # str2bool rather than bool: bool("False") is True.
    parser.add_argument("--enable_cfg", type=str2bool, default=False)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    logger.info(f"args: {args}")

    # Everything from config construction through the pipeline run is
    # executed inside the "Total Cost" profiling context.
    with ProfilingContext("Total Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = init_runner(config)

        runner.run_pipeline()