import os
from dataclasses import dataclass, fields
from typing import Union, Optional, List

import torch
import torch.distributed as dist
from packaging import version

from xfuser.logger import init_logger
import xfuser.envs as envs
# from xfuser.envs import CUDA_VERSION, TORCH_VERSION, PACKAGES_CHECKER
from xfuser.envs import TORCH_VERSION, PACKAGES_CHECKER

logger = init_logger(__name__)

env_info = PACKAGES_CHECKER.get_packages_info()
HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]
HAS_FLASH_ATTN = env_info["has_flash_attn"]


def check_packages():
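    """Ensure a diffusers version newer than 0.30.2 is installed."""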
    import diffusers

    if version.parse(diffusers.__version__) <= version.parse("0.30.2"):
        raise RuntimeError(
            "This project requires diffusers version > 0.30.2. A compatible "
            "version cannot currently be installed from PyPI via pip; "
            "please install diffusers from source."
        )


def check_env():
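    """Ensure the installed PyTorch supports CUDA graphs with NCCL (>= 2.2.0)."""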
    # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
    #if CUDA_VERSION < version.parse("11.3"):
    #    raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above")
    if TORCH_VERSION < version.parse("2.2.0"):
        # https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
        raise RuntimeError(
            "CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. "
            "If it is not released yet, please install nightly built PyTorch "
            "with `pip3 install --pre torch torchvision torchaudio --index-url "
            "https://download.pytorch.org/whl/nightly/cu121`"
        )


@dataclass
class ModelConfig:
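    """Model source settings: model name or path, optional download directory,
    and whether to trust remote code when loading."""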
    model: str
    download_dir: Optional[str] = None
    trust_remote_code: bool = False


@dataclass
class RuntimeConfig:
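    """Runtime execution options: warmup steps, compute dtype, CUDA graph,
    parallel VAE, profiler, torch.compile / onediff, and FP8 T5 encoder."""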
    warmup_steps: int = 1
    dtype: torch.dtype = torch.float16
    use_cuda_graph: bool = False
    use_parallel_vae: bool = False
    use_profiler: bool = False
    use_torch_compile: bool = False
    use_onediff: bool = False
    use_fp8_t5_encoder: bool = False

    def __post_init__(self):
        check_packages()
        if self.use_cuda_graph:
            check_env()


@dataclass
class FastAttnConfig:
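    """DiTFastAttn settings: step count, calibration sample count, threshold,
    window size, optional COCO dataset path, and cache reuse."""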
    use_fast_attn: bool = False
    n_step: int = 20
    n_calib: int = 8
    threshold: float = 0.5
    window_size: int = 64
    coco_path: Optional[str] = None
    use_cache: bool = False

    def __post_init__(self):
        assert self.n_calib > 0, "n_calib must be greater than 0"
        assert self.threshold > 0.0, "threshold must be greater than 0"


@dataclass
class DataParallelConfig:
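    """Data-parallel settings. When use_cfg_parallel is enabled, the
    classifier-free-guidance batch is split across two groups (cfg_degree = 2)."""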
    dp_degree: int = 1
    use_cfg_parallel: bool = False
    world_size: int = 1

    def __post_init__(self):
        assert self.dp_degree >= 1, "dp_degree must be greater than or equal to 1"

        # set classifier_free_guidance_degree parallel for split batch
        if self.use_cfg_parallel:
            self.cfg_degree = 2
        else:
            self.cfg_degree = 1
        assert self.dp_degree * self.cfg_degree <= self.world_size, (
            "dp_degree * cfg_degree must be less than or equal to "
            "world_size because of classifier free guidance"
        )
        assert (
            self.world_size % (self.dp_degree * self.cfg_degree) == 0
        ), "world_size must be divisible by dp_degree * cfg_degree"


@dataclass
class SequenceParallelConfig:
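    """Sequence-parallel settings: sp_degree = ulysses_degree * ring_degree.
    Requires the 'yunchang' package when sp_degree > 1."""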
    ulysses_degree: Optional[int] = None
    ring_degree: Optional[int] = None
    world_size: int = 1

    def __post_init__(self):
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
            logger.info(
                f"Ulysses degree not set, " f"using default value {self.ulysses_degree}"
            )
        if self.ring_degree is None:
            self.ring_degree = 1
            logger.info(
                f"Ring degree not set, " f"using default value {self.ring_degree}"
            )
        self.sp_degree = self.ulysses_degree * self.ring_degree

        if not HAS_LONG_CTX_ATTN and self.sp_degree > 1:
            raise ImportError(
                f"Sequence Parallel kit 'yunchang' not found but "
                f"sp_degree is {self.sp_degree}, please set it "
                f"to 1 or install 'yunchang' to use it"
            )


@dataclass
class TensorParallelConfig:
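    """Tensor-parallel settings: parallel degree and split scheme."""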
    tp_degree: int = 1
    split_scheme: Optional[str] = "row"
    world_size: int = 1

    def __post_init__(self):
        assert self.tp_degree >= 1, "tp_degree must be greater than or equal to 1"
        assert (
            self.tp_degree <= self.world_size
        ), "tp_degree must be less than or equal to world_size"


@dataclass
class PipeFusionParallelConfig:
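    """PipeFusion pipeline-parallel settings: pipeline degree, number of pipeline
    patches, and an optional per-stage list of attention layer counts."""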
    pp_degree: int = 1
    num_pipeline_patch: Optional[int] = None
    attn_layer_num_for_pp: Optional[List[int]] = None
    world_size: int = 1

    def __post_init__(self):
        assert (
            self.pp_degree is not None and self.pp_degree >= 1
        ), "pp_degree must be set and greater than or equal to 1 to use PipeFusion"
        assert (
            self.pp_degree <= self.world_size
        ), "pipefusion_degree must be less than or equal to world_size"
        if self.num_pipeline_patch is None:
            self.num_pipeline_patch = self.pp_degree
            logger.info(
                f"Pipeline patch number not set, "
                f"using default value {self.pp_degree}"
            )
        if self.attn_layer_num_for_pp is not None:
            logger.info(
                f"attn_layer_num_for_pp set, splitting attention layers"
                f"to {self.attn_layer_num_for_pp}"
            )
            assert len(self.attn_layer_num_for_pp) == self.pp_degree, (
                "attn_layer_num_for_pp must have the same "
                "length as pp_degree if not None"
            )
        if self.pp_degree == 1 and self.num_pipeline_patch > 1:
            logger.warning(
                f"Pipefusion degree is 1, pipeline will not be used,"
                f"num_pipeline_patch will be ignored"
            )
            self.num_pipeline_patch = 1


@dataclass
class ParallelConfig:
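    """Aggregates the data, sequence, tensor, and PipeFusion parallel configs and
    checks that the product of all parallel degrees equals world_size."""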
    dp_config: DataParallelConfig
    sp_config: SequenceParallelConfig
    pp_config: PipeFusionParallelConfig
    tp_config: TensorParallelConfig
    world_size: int = 1 # FIXME: remove this
    worker_cls: str = "xfuser.ray.worker.worker.Worker"

    def __post_init__(self):
        assert self.tp_config is not None, "tp_config must be set"
        assert self.dp_config is not None, "dp_config must be set"
        assert self.sp_config is not None, "sp_config must be set"
        assert self.pp_config is not None, "pp_config must be set"
        parallel_world_size = (
            self.dp_config.dp_degree
            * self.dp_config.cfg_degree
            * self.sp_config.sp_degree
            * self.tp_config.tp_degree
            * self.pp_config.pp_degree
        )
        world_size = self.world_size
        assert parallel_world_size == world_size, (
            f"parallel_world_size {parallel_world_size} "
            f"must be equal to world_size {self.world_size}"
        )
        assert (
            world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
        ), "world_size must be divisible by dp_degree * cfg_degree"
        assert (
            world_size % self.pp_config.pp_degree == 0
        ), "world_size must be divisible by pp_degree"
        assert (
            world_size % self.sp_config.sp_degree == 0
        ), "world_size must be divisible by sp_degree"
        assert (
            world_size % self.tp_config.tp_degree == 0
        ), "world_size must be divisible by tp_degree"
        self.dp_degree = self.dp_config.dp_degree
        self.cfg_degree = self.dp_config.cfg_degree
        self.sp_degree = self.sp_config.sp_degree
        self.pp_degree = self.pp_config.pp_degree
        self.tp_degree = self.tp_config.tp_degree

        self.ulysses_degree = self.sp_config.ulysses_degree
        self.ring_degree = self.sp_config.ring_degree


@dataclass(frozen=True)
class EngineConfig:
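    """Top-level engine configuration bundling model, runtime, parallel, and
    DiTFastAttn configs."""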
    model_config: ModelConfig
    runtime_config: RuntimeConfig
    parallel_config: ParallelConfig
    fast_attn_config: FastAttnConfig

    def __post_init__(self):
        world_size = self.parallel_config.world_size
        if self.fast_attn_config.use_fast_attn:
            assert (
                self.parallel_config.dp_degree == world_size
            ), "world_size must be equal to dp_degree when using DiTFastAttn"

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs."""
        return dict((field.name, getattr(self, field.name)) for field in fields(self))


@dataclass
class InputConfig:
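    """Generation input settings: resolution, frame count, prompts, inference
    steps, sequence length, seed, and output type."""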
    height: int = 1024
    width: int = 1024
    num_frames: int = 49
    use_resolution_binning: bool = True
    batch_size: Optional[int] = None
    img_file_path: Optional[str] = None
    prompt: Union[str, List[str]] = ""
    negative_prompt: Union[str, List[str]] = ""
    num_inference_steps: int = 20
    max_sequence_length: int = 256
    seed: int = 42
    output_type: str = "pil"

    def __post_init__(self):
        if isinstance(self.prompt, list):
            assert (
                len(self.prompt) == len(self.negative_prompt)
                or len(self.negative_prompt) == 0
            ), "prompts and negative_prompts must have the same quantities"
            self.batch_size = self.batch_size or len(self.prompt)
        else:
            self.batch_size = self.batch_size or 1
        assert self.output_type in [
            "pil",
            "latent",
            "pt",
        ], "output_pil must be either 'pil' or 'latent'"