# set_config.py
import json
import os

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed.tensor.device_mesh import init_device_mesh

from lightx2v.utils.input_info import ALL_INPUT_INFO_KEYS
from lightx2v.utils.lockable_dict import LockableDict


def get_default_config():
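    """Return the default LightX2V config, wrapped in a LockableDict."""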
    default_config = {
        "do_mm_calib": False,
        "cpu_offload": False,
        "max_area": False,
        "vae_stride": (4, 8, 8),
        "patch_size": (1, 2, 2),
        "feature_caching": "NoCaching",  # ["NoCaching", "TaylorSeer", "Tea"]
        "teacache_thresh": 0.26,
        "use_ret_steps": False,
        "use_bfloat16": True,
        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
        "use_prompt_enhancer": False,
        "parallel": False,
        "seq_parallel": False,
        "cfg_parallel": False,
        "enable_cfg": False,
        "use_image_encoder": True,
    }
    default_config = LockableDict(default_config)
    return default_config


def set_config(args):
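    """Build the runtime config.

    Starts from the defaults, then layers on CLI args (minus input-info keys),
    an optional user-provided config_json, and any model / quantized-checkpoint
    config.json found on disk; later sources override earlier ones.
    """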
    config = get_default_config()
    config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS})

    if config.get("config_json", None) is not None:
        logger.info(f"Loading some config from {config['config_json']}")
        with open(config["config_json"], "r") as f:
            config_json = json.load(f)
        config.update(config_json)

    if os.path.exists(os.path.join(config["model_path"], "config.json")):
        with open(os.path.join(config["model_path"], "config.json"), "r") as f:
helloyongyang's avatar
helloyongyang committed
48
49
            model_config = json.load(f)
        config.update(model_config)
50
51
    elif os.path.exists(os.path.join(config["model_path"], "low_noise_model", "config.json")):  # 需要一个更优雅的update方法
        with open(os.path.join(config["model_path"], "low_noise_model", "config.json"), "r") as f:
52
53
            model_config = json.load(f)
        config.update(model_config)
54
55
    elif os.path.exists(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json")):  # 需要一个更优雅的update方法
        with open(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json"), "r") as f:
gushiqiao's avatar
gushiqiao committed
56
57
            model_config = json.load(f)
        config.update(model_config)
58
59
    elif os.path.exists(os.path.join(config["model_path"], "original", "config.json")):
        with open(os.path.join(config["model_path"], "original", "config.json"), "r") as f:
gushiqiao's avatar
gushiqiao committed
60
61
62
            model_config = json.load(f)
        config.update(model_config)
    # load quantized config
    if config.get("dit_quantized_ckpt", None) is not None:
        config_path = os.path.join(config["dit_quantized_ckpt"], "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                model_config = json.load(f)
            config.update(model_config)

    if config["task"] in ["i2v", "s2v"]:
        if config["target_video_length"] % config["vae_stride"][0] != 1:
            logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
            config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1

    if config["task"] not in ["t2i", "i2i"]:
        config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
        if config["model_cls"] == "seko_talk":
            config["attnmap_frame_num"] += 1

    return config


def set_parallel_config(config):
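    """Initialize the device mesh and derive seq/cfg parallel flags from config["parallel"]."""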
    if config["parallel"]:
        cfg_p_size = config["parallel"].get("cfg_p_size", 1)
        seq_p_size = config["parallel"].get("seq_p_size", 1)
        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must equal world_size ({dist.get_world_size()})"
        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))

        if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1:
91
92
            config["seq_parallel"] = True

        if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1:
94
            config["cfg_parallel"] = True
        # Warm up the communication backend with a dummy all_reduce.
        _a = torch.zeros([1]).to(f"cuda:{dist.get_rank()}")
        dist.all_reduce(_a)


def print_config(config):
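    """Log the config as pretty-printed JSON, dropping the non-JSON-serializable
    device_mesh; under parallel execution only rank 0 logs.
    """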
    config_to_print = config.copy()
    config_to_print.pop("device_mesh", None)
    if config["parallel"]:
        if dist.get_rank() == 0:
            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
    else:
        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")