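"""Configuration helpers for lightx2v.

Builds the default config, merges CLI arguments with JSON and per-model
configs, and sets up distributed parallelism (CFG / sequence parallel).
"""
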
import json
import os

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed.tensor.device_mesh import init_device_mesh

from lightx2v.utils.input_info import ALL_INPUT_INFO_KEYS
from lightx2v.utils.lockable_dict import LockableDict


def get_default_config():
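    """Return the default pipeline configuration as a LockableDict."""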
    default_config = {
        "do_mm_calib": False,
        "cpu_offload": False,
        "max_area": False,
        "vae_stride": (4, 8, 8),
        "patch_size": (1, 2, 2),
        "feature_caching": "NoCaching",  # ["NoCaching", "TaylorSeer", "Tea"]
        "teacache_thresh": 0.26,
        "use_ret_steps": False,
        "use_bfloat16": True,
        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
        "use_prompt_enhancer": False,
        "parallel": False,
        "seq_parallel": False,
        "cfg_parallel": False,
        "enable_cfg": False,
        "use_image_encoder": True,
    }
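    # LockableDict is assumed to let callers freeze the config against
    # accidental mutation later (see lightx2v.utils.lockable_dict).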
    default_config = LockableDict(default_config)
    return default_config


def set_config(args):
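    """Build the run config.

    Precedence (later overrides earlier): defaults < CLI args < config_json
    < per-model config.json. Keys in ALL_INPUT_INFO_KEYS are per-request
    inputs and are excluded here.
    """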
    config = get_default_config()
    config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS})

    if config.get("config_json", None) is not None:
        logger.info(f"Loading some config from {config['config_json']}")
        with open(config["config_json"], "r") as f:
            config_json = json.load(f)
        config.update(config_json)

    if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:  # Special config for hunyuan video 1.5 model folder structure
        config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"])  # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v]
        if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")):
            with open(os.path.join(config["transformer_model_path"], "config.json"), "r") as f:
                model_config = json.load(f)
            config.update(model_config)
    else:
        # Try candidate config.json locations in priority order; the first
        # existing file wins. TODO: a more elegant update mechanism is needed.
        candidate_subdirs = [
            "",
            "low_noise_model",
            os.path.join("distill_models", "low_noise_model"),
            "original",
        ]
        for subdir in candidate_subdirs:
            model_config_path = os.path.join(config["model_path"], subdir, "config.json")
            if os.path.exists(model_config_path):
                with open(model_config_path, "r") as f:
                    model_config = json.load(f)
                config.update(model_config)
                break
        # Load quantized config if a quantized checkpoint is given
        if config.get("dit_quantized_ckpt", None) is not None:
            config_path = os.path.join(config["dit_quantized_ckpt"], "config.json")
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    model_config = json.load(f)
                config.update(model_config)

    if config["task"] in ["i2v", "s2v"]:
        if config["target_video_length"] % config["vae_stride"][0] != 1:
            logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
            config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1

    if config["task"] not in ["t2i", "i2i"] and config["model_cls"] not in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
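        # Frame count in latent space: VAE temporal stride, then temporal patching.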
        config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
        if config["model_cls"] == "seko_talk":
            config["attnmap_frame_num"] += 1

    return config


def set_parallel_config(config):
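    """Derive seq/cfg parallel flags and build the 2D (cfg_p, seq_p) device mesh."""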
    if config["parallel"]:
        cfg_p_size = config["parallel"].get("cfg_p_size", 1)
        seq_p_size = config["parallel"].get("seq_p_size", 1)
        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must equal world_size ({dist.get_world_size()})"
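        # 2D mesh: one axis for CFG-branch parallelism, one for sequence parallelism.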
        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))

        if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1:
98
99
            config["seq_parallel"] = True

        if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1:
101
            config["cfg_parallel"] = True
        # Warm up the process group with a dummy all-reduce
        _a = torch.zeros([1]).to(f"cuda:{dist.get_rank()}")
        dist.all_reduce(_a)


def print_config(config):
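    """Log the config as pretty-printed JSON; only rank 0 logs in parallel runs.

    The device_mesh entry is dropped first since it is not JSON-serializable.
    """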
    config_to_print = config.copy()
    config_to_print.pop("device_mesh", None)
    if config["parallel"]:
        if dist.get_rank() == 0:
            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
    else:
        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
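

# Minimal usage sketch (hypothetical values; the real CLI wiring lives
# elsewhere in lightx2v, and set_parallel_config additionally requires an
# initialized torch.distributed process group):
#
#   import argparse
#
#   args = argparse.Namespace(
#       model_cls="seko_talk",        # hypothetical model choice
#       model_path="/path/to/model",  # hypothetical path
#       task="i2v",
#       target_video_length=81,
#       config_json=None,
#   )
#   config = set_config(args)
#   if config["parallel"]:
#       set_parallel_config(config)
#   print_config(config)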