import json
import os

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed.tensor.device_mesh import init_device_mesh

from lightx2v.utils.input_info import ALL_INPUT_INFO_KEYS
from lightx2v.utils.lockable_dict import LockableDict
from lightx2v_platform.base.global_var import AI_DEVICE

def get_default_config():
    """Return the default inference config as a LockableDict."""
    default_config = {
        "do_mm_calib": False,
        "cpu_offload": False,
        "max_area": False,
        "vae_stride": (4, 8, 8),
        "patch_size": (1, 2, 2),
        "feature_caching": "NoCaching",  # ["NoCaching", "TaylorSeer", "Tea"]
        "teacache_thresh": 0.26,
        "use_ret_steps": False,
        "use_bfloat16": True,
        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
        "use_prompt_enhancer": False,
        "parallel": False,
        "seq_parallel": False,
        "cfg_parallel": False,
        "enable_cfg": False,
        "use_image_encoder": True,
    }
    default_config = LockableDict(default_config)
    return default_config


def set_config(args):
    """Build the runtime config: defaults, then CLI args, then JSON and model-config overrides."""
    config = get_default_config()
    config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS})

    if config.get("config_json", None) is not None:
        logger.info(f"Loading some config from {config['config_json']}")
        with open(config["config_json"], "r") as f:
            config_json = json.load(f)
        config.update(config_json)
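    # For illustration, a (hypothetical) config_json could override any keys defined above,
    # e.g. {"enable_cfg": true, "parallel": {"cfg_p_size": 2, "seq_p_size": 2}}.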

    if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:  # Special config for hunyuan video 1.5 model folder structure
        config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"])  # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v]
        if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")):
            with open(os.path.join(config["transformer_model_path"], "config.json"), "r") as f:
                model_config = json.load(f)
            config.update(model_config)
    else:
        # Look for the model's config.json in the known checkpoint layouts; the first match wins.
        # TODO: a more elegant update method.
        candidate_dirs = [
            config["model_path"],
            os.path.join(config["model_path"], "low_noise_model"),
            os.path.join(config["model_path"], "distill_models", "low_noise_model"),
            os.path.join(config["model_path"], "original"),
        ]
        for candidate_dir in candidate_dirs:
            config_path = os.path.join(candidate_dir, "config.json")
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    model_config = json.load(f)
                config.update(model_config)
                break
        # Load the quantized checkpoint's config.json, if provided; its keys override the base model config.
        if config.get("dit_quantized_ckpt", None) is not None:
            config_path = os.path.join(config["dit_quantized_ckpt"], "config.json")
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    model_config = json.load(f)
                config.update(model_config)

    if config["task"] in ["i2v", "s2v"]:
        if config["target_video_length"] % config["vae_stride"][0] != 1:
            logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
            config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1

    if config["task"] not in ["t2i", "i2i"] and config["model_cls"] not in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
Watebear's avatar
Watebear committed
84
85
86
        config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
        if config["model_cls"] == "seko_talk":
            config["attnmap_frame_num"] += 1

    return config


def set_parallel_config(config):
    """Validate parallel sizes, build the device mesh, and set the parallel flags."""
    if config["parallel"]:
        cfg_p_size = config["parallel"].get("cfg_p_size", 1)
        seq_p_size = config["parallel"].get("seq_p_size", 1)
        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must equal world_size ({dist.get_world_size()})"
        config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))

        if seq_p_size > 1:
            config["seq_parallel"] = True

        if config.get("enable_cfg", False) and cfg_p_size > 1:
            config["cfg_parallel"] = True

        # Warm up the process group with a dummy all-reduce.
        # Note: indexing the device by global rank assumes one process per device on a single node.
        _a = torch.zeros([1]).to(f"{AI_DEVICE}:{dist.get_rank()}")
        dist.all_reduce(_a)


def print_config(config):
    """Log the config on rank 0 (or unconditionally when not running in parallel)."""
    config_to_print = config.copy()
    config_to_print.pop("device_mesh", None)  # the device mesh object is not JSON-serializable
    if config["parallel"]:
        if dist.get_rank() == 0:
            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
    else:
        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")