"""Command-line entry point for lightx2v video generation.

Parses the CLI arguments, builds the run configuration with ``set_config``,
calls ``seed_all(config.seed)``, optionally initialises an NCCL process group
(when ``config.parallel_attn_type`` is set), and then runs the pipeline of the
runner registered under ``config.model_cls`` -- wrapped in a ``GraphRunner``
when graph mode is enabled.
"""

import argparse
import torch
import torch.distributed as dist
import json

# NOTE(review): this star import is presumably where CHECK_ENABLE_GRAPH_MODE
# (used in main()) comes from; confirm before narrowing it.
from lightx2v.utils.envs import *

from lightx2v.utils.utils import seed_all
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.utils.registry_factory import RUNNER_REGISTER

# HunyuanRunner and WanRunner are not referenced by name below; they are
# presumably imported so that they register themselves in RUNNER_REGISTER
# (TODO confirm) -- do not remove them as "unused".
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.graph_runner import GraphRunner

# NOTE(review): presumably imported for its side effects (op registration);
# verify before removing.
from lightx2v.common.ops import *


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    # required=True means argparse never falls back to a default, so none is given.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)

    parser.add_argument("--image_path", type=str, default=None, help="The path to input image file or path for image-to-video (i2v) task")
    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")

    parser.add_argument("--prompt", type=str, required=True)

    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--config_json", type=str, required=True)
    return parser.parse_args()


def main():
    """Run one generation job as configured on the command line.

    The whole job, including config construction and process-group setup and
    teardown, is timed under the "Total Cost" profiling context.
    """
    args = _parse_args()
    print(f"args: {args}")

    with ProfilingContext("Total Cost"):
        config = set_config(args)
        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")

        seed_all(config.seed)

        # Only create a process group when a parallel attention mode is
        # configured; remember that we did so we tear down only what we created.
        use_dist = bool(config.parallel_attn_type)
        if use_dist:
            dist.init_process_group(backend="nccl")

        if CHECK_ENABLE_GRAPH_MODE():
            default_runner = RUNNER_REGISTER[config.model_cls](config)
            runner = GraphRunner(default_runner)
        else:
            runner = RUNNER_REGISTER[config.model_cls](config)
        runner.run_pipeline()

        # Explicitly release the process group on normal completion instead of
        # leaving it to interpreter shutdown. This is deliberately not in a
        # `finally`, so a failure on one rank cannot block this rank's exit
        # inside teardown.
        if use_dist:
            dist.destroy_process_group()


if __name__ == "__main__":
    main()