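"""Assemble the runtime configuration.

Sources are merged in increasing priority: built-in defaults, CLI
arguments, the user's config JSON, the model's own config.json, and
(optionally) a quantized checkpoint's config.json.
"""
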
import json
import os

import torch.distributed as dist
from easydict import EasyDict
from loguru import logger
from torch.distributed.tensor.device_mesh import init_device_mesh


def get_default_config():
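    """Return the baseline config; every entry may be overridden by later sources."""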
    default_config = {
        "do_mm_calib": False,
        "cpu_offload": False,
        "max_area": False,
        "vae_stride": (4, 8, 8),
        "patch_size": (1, 2, 2),
        "feature_caching": "NoCaching",  # ["NoCaching", "TaylorSeer", "Tea"]
        "teacache_thresh": 0.26,
        "use_ret_steps": False,
        "use_bfloat16": True,
        "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
        "mm_config": {},
        "use_prompt_enhancer": False,
        "parallel": False,
        "seq_parallel": False,
        "cfg_parallel": False,
        "enable_cfg": False,
        "use_image_encoder": True,
    }
    return default_config


def set_config(args):
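    """Build the final run config as an EasyDict.

    Later sources override earlier ones: defaults < CLI `args` <
    `config.config_json` < the model's config.json < a quantized
    checkpoint's config.json.
    """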
    config = get_default_config()
    config.update(vars(args))
    config = EasyDict(config)

    with open(config.config_json, "r") as f:
        config_json = json.load(f)
    config.update(config_json)

    # Merge the model's own config.json, checking candidate locations in order.
    for sub_dir in ("", "low_noise_model", "original"):
        model_config_path = os.path.join(config.model_path, sub_dir, "config.json")
        if os.path.exists(model_config_path):
            with open(model_config_path, "r") as f:
                model_config = json.load(f)
            config.update(model_config)
            break
    # load quantized config
    if config.get("dit_quantized_ckpt", None) is not None:
        config_path = os.path.join(config.dit_quantized_ckpt, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                model_config = json.load(f)
            config.update(model_config)

    if config.task == "i2v":
        # The temporal VAE stride requires (num_frames - 1) to be divisible by vae_stride[0].
        if (config.target_video_length - 1) % config.vae_stride[0] != 0:
            logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
            config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

    return config
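

# Minimal usage sketch (illustrative only; assumes an argparse namespace that
# exposes at least `config_json`, `model_path`, `task`, and
# `target_video_length`, i.e. the attributes read above; all values shown
# here are placeholders, not defaults of this module):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--config_json", required=True)
#   parser.add_argument("--model_path", required=True)
#   parser.add_argument("--task", default="t2v")
#   parser.add_argument("--target_video_length", type=int, default=81)
#   config = set_config(parser.parse_args())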


def set_parallel_config(config):
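    """Initialize distributed state for CFG- and sequence-parallel runs.

    Builds a 2D device mesh of shape (cfg_p_size, seq_p_size) with dims
    named "cfg_p" and "seq_p", and enables the `cfg_parallel` /
    `seq_parallel` flags when the corresponding dimension is larger than 1.
    """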
    if config.parallel:
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        cfg_p_size = config.parallel.get("cfg_p_size", 1)
        seq_p_size = config.parallel.get("seq_p_size", 1)
        assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size ({cfg_p_size}) * seq_p_size ({seq_p_size}) must equal world_size ({dist.get_world_size()})"
        config["device_mesh"] = init_device_mesh("cuda", (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))

        if seq_p_size > 1:
            config["seq_parallel"] = True

        if config.get("enable_cfg", False) and cfg_p_size > 1:
            config["cfg_parallel"] = True