args.py 13.3 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import sys
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, List, Tuple, Union

import torch
import torch.distributed

from xfuser.logger import init_logger
from xfuser.core.distributed import init_distributed_environment
from xfuser.config.config import (
    EngineConfig,
    FastAttnConfig,
    ParallelConfig,
    TensorParallelConfig,
    PipeFusionParallelConfig,
    SequenceParallelConfig,
    DataParallelConfig,
    ModelConfig,
    InputConfig,
    RuntimeConfig,
)

# Module-level logger created through xfuser's init_logger helper.
logger = init_logger(__name__)


class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names.

    A long option may be typed with either separator (``--warmup-steps`` or
    ``--warmup_steps``) regardless of which spelling was used when the option
    was registered with ``add_argument``.
    """

    def parse_args(self, args=None, namespace=None):
        """Normalise long option names, then delegate to ``ArgumentParser``.

        Args:
            args: Argument strings to parse; defaults to ``sys.argv[1:]``.
            namespace: Optional namespace object to populate.

        Returns:
            The populated namespace, as returned by
            ``argparse.ArgumentParser.parse_args``.
        """
        if args is None:
            args = sys.argv[1:]

        processed_args = []
        options_ended = False
        for arg in args:
            if options_ended or not arg.startswith("--"):
                processed_args.append(arg)
            elif arg == "--":
                # A bare "--" ends option parsing: everything after it is a
                # positional value and must be passed through untouched.
                options_ended = True
                processed_args.append(arg)
            else:
                # Only the option name is normalised; a value attached with
                # "=" keeps its dashes/underscores exactly as typed.
                key, sep, value = arg.partition("=")
                processed_args.append(self._resolve_option_name(key) + sep + value)

        return super().parse_args(processed_args, namespace)

    def _resolve_option_name(self, key):
        """Map a ``--long-option`` token onto the spelling that is registered.

        Tries the token as typed, then its all-underscore form, then its
        all-dash form. If none is registered (unknown option or an
        abbreviation), the all-underscore form is returned so argparse's own
        prefix matching and error reporting take over.
        """
        name = key[len("--") :]
        underscored = "--" + name.replace("-", "_")
        dashed = "--" + name.replace("_", "-")
        # argparse exposes no public list of registered option strings; this
        # mapping is shared between the parser and its argument groups.
        registered = getattr(self, "_option_string_actions", {})
        for candidate in (key, underscored, dashed):
            if candidate in registered:
                return candidate
        return underscored


def nullable_str(val: str):
    """Return ``val`` unchanged, or ``None`` for an empty value or the text "None".

    Used as an argparse ``type=`` converter so that a literal ``None`` typed
    on the command line becomes a real ``None``.
    """
    is_real_value = bool(val) and val != "None"
    return val if is_real_value else None


@dataclass
class xFuserArgs:
    """Arguments for xFuser engine.

    Flat container holding every engine and input option. ``add_cli_args``
    registers one CLI option per field, ``from_cli_args`` builds an instance
    from a parsed namespace, and ``create_config`` converts the fields into
    an ``EngineConfig`` / ``InputConfig`` pair.
    """

    # Model arguments (passed to ModelConfig by create_config)
    # `model` has no dataclass default, so it is required when constructing
    # xFuserArgs directly; the CLI supplies its own default in add_cli_args.
    model: str
    download_dir: Optional[str] = None
    trust_remote_code: bool = False
    # Runtime arguments (passed to RuntimeConfig by create_config)
    warmup_steps: int = 1
    # use_cuda_graph: bool = True
    use_parallel_vae: bool = False
    # use_profiler: bool = False
    use_torch_compile: bool = False
    use_onediff: bool = False
    # Parallel arguments (each sub-group feeds one sub-config of ParallelConfig)
    # data parallel
    data_parallel_degree: int = 1
    use_cfg_parallel: bool = False
    # sequence parallel
    ulysses_degree: Optional[int] = None
    ring_degree: Optional[int] = None
    # tensor parallel
    tensor_parallel_degree: int = 1
    split_scheme: Optional[str] = "row"
    # pipefusion parallel
    pipefusion_parallel_degree: int = 1
    num_pipeline_patch: Optional[int] = None
    attn_layer_num_for_pp: Optional[List[int]] = None
    # Input arguments (passed to InputConfig by create_config)
    height: int = 1024
    width: int = 1024
    num_frames: int = 49
    # Also forwarded to FastAttnConfig as `n_step`.
    num_inference_steps: int = 20
    max_sequence_length: int = 256
    # A list of prompts sets InputConfig.batch_size to its length; a plain
    # string gives batch_size 1.
    prompt: Union[str, List[str]] = ""
    negative_prompt: Union[str, List[str]] = ""
    # Inverted flag: create_config passes `use_resolution_binning = not <this>`.
    no_use_resolution_binning: bool = False
    seed: int = 42
    output_type: str = "pil"
    # The four enable_* flags below are not forwarded by create_config;
    # presumably callers read them straight from this object — TODO confirm.
    enable_model_cpu_offload: bool = False
    enable_sequential_cpu_offload: bool = False
    enable_tiling: bool = False
    enable_slicing: bool = False
    # DiTFastAttn arguments (passed to FastAttnConfig by create_config)
    use_fast_attn: bool = False
    n_calib: int = 8
    threshold: float = 0.5
    window_size: int = 64
    coco_path: Optional[str] = None
    use_cache: bool = False
    # Declared last, but forwarded to RuntimeConfig rather than FastAttnConfig.
    use_fp8_t5_encoder: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser):
        """Register the shared xFuser engine CLI arguments on ``parser``.

        Options are organised into argument groups (model, runtime, parallel,
        input, DiTFastAttn). Every option's ``dest`` is the name of an
        ``xFuserArgs`` field, so the parsed namespace can be handed to
        ``from_cli_args``. Defaults are read from the dataclass fields so the
        CLI and the dataclass cannot drift apart; ``--model`` is the only
        exception because the dataclass declares no default for it.

        Args:
            parser: Parser to extend in place.

        Returns:
            The same ``parser`` object.
        """
        # Model arguments
        model_group = parser.add_argument_group("Model Options")
        model_group.add_argument(
            "--model",
            type=str,
            default="PixArt-alpha/PixArt-XL-2-1024-MS",
            help="Name or path of the huggingface model to use.",
        )
        model_group.add_argument(
            "--download-dir",
            type=nullable_str,
            default=xFuserArgs.download_dir,
            help="Directory to download and load the weights, default to the default cache dir of huggingface.",
        )
        model_group.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Trust remote code from huggingface.",
        )

        # Runtime arguments
        runtime_group = parser.add_argument_group("Runtime Options")
        runtime_group.add_argument(
            "--warmup_steps",
            type=int,
            default=xFuserArgs.warmup_steps,
            help="Warmup steps in generation.",
        )
        # runtime_group.add_argument("--use_cuda_graph", action="store_true")
        runtime_group.add_argument("--use_parallel_vae", action="store_true")
        # runtime_group.add_argument("--use_profiler", action="store_true")
        runtime_group.add_argument(
            "--use_torch_compile",
            action="store_true",
            help="Enable torch.compile to accelerate inference in a single card",
        )
        runtime_group.add_argument(
            "--use_onediff",
            action="store_true",
            help="Enable onediff to accelerate inference in a single card",
        )

        # Parallel arguments
        parallel_group = parser.add_argument_group("Parallel Processing Options")
        parallel_group.add_argument(
            "--use_cfg_parallel",
            action="store_true",
            help="Use split batch in classifier_free_guidance. cfg_degree will be 2 if set",
        )
        parallel_group.add_argument(
            "--data_parallel_degree",
            type=int,
            default=xFuserArgs.data_parallel_degree,
            help="Data parallel degree.",
        )
        parallel_group.add_argument(
            "--ulysses_degree",
            type=int,
            default=xFuserArgs.ulysses_degree,
            help="Ulysses sequence parallel degree. Used in attention layer.",
        )
        parallel_group.add_argument(
            "--ring_degree",
            type=int,
            default=xFuserArgs.ring_degree,
            help="Ring sequence parallel degree. Used in attention layer.",
        )
        parallel_group.add_argument(
            "--pipefusion_parallel_degree",
            type=int,
            default=xFuserArgs.pipefusion_parallel_degree,
            help="Pipefusion parallel degree. Indicates the number of pipeline stages.",
        )
        parallel_group.add_argument(
            "--num_pipeline_patch",
            type=int,
            default=xFuserArgs.num_pipeline_patch,
            help="Number of patches the feature map should be segmented in pipefusion parallel.",
        )
        parallel_group.add_argument(
            "--attn_layer_num_for_pp",
            default=xFuserArgs.attn_layer_num_for_pp,
            nargs="*",
            type=int,
            help="List representing the number of layers per stage of the pipeline in pipefusion parallel",
        )
        parallel_group.add_argument(
            "--tensor_parallel_degree",
            type=int,
            default=xFuserArgs.tensor_parallel_degree,
            help="Tensor parallel degree.",
        )
        parallel_group.add_argument(
            "--split_scheme",
            type=str,
            default=xFuserArgs.split_scheme,
            help="Split scheme for tensor parallel.",
        )

        # Input arguments
        input_group = parser.add_argument_group("Input Options")
        input_group.add_argument(
            "--height",
            type=int,
            default=xFuserArgs.height,
            help="The height of image",
        )
        input_group.add_argument(
            "--width",
            type=int,
            default=xFuserArgs.width,
            help="The width of image",
        )
        input_group.add_argument(
            "--num_frames",
            type=int,
            default=xFuserArgs.num_frames,
            help="The frames of video",
        )
        # nargs="*" means a prompt given on the command line arrives as a
        # list of strings; when the flag is omitted the value is the plain
        # string default.
        input_group.add_argument(
            "--prompt",
            type=str,
            nargs="*",
            default=xFuserArgs.prompt,
            help="Prompt for the model.",
        )
        input_group.add_argument("--no_use_resolution_binning", action="store_true")
        input_group.add_argument(
            "--negative_prompt",
            type=str,
            nargs="*",
            default=xFuserArgs.negative_prompt,
            help="Negative prompt for the model.",
        )
        input_group.add_argument(
            "--num_inference_steps",
            type=int,
            default=xFuserArgs.num_inference_steps,
            help="Number of inference steps.",
        )
        input_group.add_argument(
            "--max_sequence_length",
            type=int,
            default=xFuserArgs.max_sequence_length,
            help="Max sequencen length of prompt",
        )
        # The remaining options are listed under "Runtime Options" in --help.
        runtime_group.add_argument(
            "--seed",
            type=int,
            default=xFuserArgs.seed,
            help="Random seed for operations.",
        )
        runtime_group.add_argument(
            "--output_type",
            type=str,
            default=xFuserArgs.output_type,
            help="Output type of the pipeline.",
        )
        runtime_group.add_argument(
            "--enable_sequential_cpu_offload",
            action="store_true",
            help="Offloading the weights to the CPU.",
        )
        runtime_group.add_argument(
            "--enable_model_cpu_offload",
            action="store_true",
            help="Offloading the weights to the CPU.",
        )
        runtime_group.add_argument(
            "--enable_tiling",
            action="store_true",
            help="Making VAE decode a tile at a time to save GPU memory.",
        )
        # Help text previously duplicated the tiling description verbatim.
        runtime_group.add_argument(
            "--enable_slicing",
            action="store_true",
            help="Making VAE decode a slice at a time to save GPU memory.",
        )
        runtime_group.add_argument(
            "--use_fp8_t5_encoder",
            action="store_true",
            help="Quantize the T5 text encoder.",
        )

        # DiTFastAttn arguments
        fast_attn_group = parser.add_argument_group("DiTFastAttn Options")
        fast_attn_group.add_argument(
            "--use_fast_attn",
            action="store_true",
            help="Use DiTFastAttn to accelerate single inference. Only data parallelism can be used with DITFastAttn.",
        )
        fast_attn_group.add_argument(
            "--n_calib",
            type=int,
            default=xFuserArgs.n_calib,
            help="Number of prompts for compression method seletion.",
        )
        fast_attn_group.add_argument(
            "--threshold",
            type=float,
            default=xFuserArgs.threshold,
            help="Threshold for selecting attention compression method.",
        )
        fast_attn_group.add_argument(
            "--window_size",
            type=int,
            default=xFuserArgs.window_size,
            help="Size of window attention.",
        )
        fast_attn_group.add_argument(
            "--coco_path",
            type=str,
            default=xFuserArgs.coco_path,
            help="Path of MS COCO annotation json file.",
        )
        fast_attn_group.add_argument(
            "--use_cache",
            action="store_true",
            help="Use cache config for attention compression.",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed argparse namespace.

        Every dataclass field is looked up by name on ``args``; a field that
        is missing from the namespace raises ``AttributeError``. Attributes
        on ``args`` that are not dataclass fields are ignored.
        """
        field_names = (field.name for field in dataclasses.fields(cls))
        init_kwargs = {name: getattr(args, name) for name in field_names}
        return cls(**init_kwargs)

    def create_config(self) -> Tuple[EngineConfig, InputConfig]:
        """Translate these arguments into engine and input configuration objects.

        If ``torch.distributed`` is not yet initialized, a warning is logged
        and ``init_distributed_environment()`` is called before any config
        object is built.

        Returns:
            An ``(EngineConfig, InputConfig)`` tuple. The engine config bundles
            the model, runtime, parallel and fast-attention configs.
        """
        if not torch.distributed.is_initialized():
            logger.warning("Distributed environment is not initialized. Initializing...")
            init_distributed_environment()

        model_config = ModelConfig(
            model=self.model,
            download_dir=self.download_dir,
            trust_remote_code=self.trust_remote_code,
        )

        runtime_config = RuntimeConfig(
            warmup_steps=self.warmup_steps,
            # use_cuda_graph=self.use_cuda_graph,
            use_parallel_vae=self.use_parallel_vae,
            use_torch_compile=self.use_torch_compile,
            use_onediff=self.use_onediff,
            # use_profiler=self.use_profiler,
            use_fp8_t5_encoder=self.use_fp8_t5_encoder,
        )

        # One sub-config per parallelism flavour, built in the same order in
        # which they are handed to ParallelConfig.
        data_parallel = DataParallelConfig(
            dp_degree=self.data_parallel_degree,
            use_cfg_parallel=self.use_cfg_parallel,
        )
        sequence_parallel = SequenceParallelConfig(
            ulysses_degree=self.ulysses_degree,
            ring_degree=self.ring_degree,
        )
        tensor_parallel = TensorParallelConfig(
            tp_degree=self.tensor_parallel_degree,
            split_scheme=self.split_scheme,
        )
        pipefusion_parallel = PipeFusionParallelConfig(
            pp_degree=self.pipefusion_parallel_degree,
            num_pipeline_patch=self.num_pipeline_patch,
            attn_layer_num_for_pp=self.attn_layer_num_for_pp,
        )
        parallel_config = ParallelConfig(
            dp_config=data_parallel,
            sp_config=sequence_parallel,
            tp_config=tensor_parallel,
            pp_config=pipefusion_parallel,
        )

        fast_attn_config = FastAttnConfig(
            use_fast_attn=self.use_fast_attn,
            n_step=self.num_inference_steps,
            n_calib=self.n_calib,
            threshold=self.threshold,
            window_size=self.window_size,
            coco_path=self.coco_path,
            use_cache=self.use_cache,
        )

        engine_config = EngineConfig(
            model_config=model_config,
            runtime_config=runtime_config,
            parallel_config=parallel_config,
            fast_attn_config=fast_attn_config,
        )

        # A list of prompts is one batch entry per prompt; a single string
        # counts as a batch of one.
        if isinstance(self.prompt, list):
            batch_size = len(self.prompt)
        else:
            batch_size = 1

        input_config = InputConfig(
            height=self.height,
            width=self.width,
            num_frames=self.num_frames,
            use_resolution_binning=not self.no_use_resolution_binning,
            batch_size=batch_size,
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            num_inference_steps=self.num_inference_steps,
            max_sequence_length=self.max_sequence_length,
            seed=self.seed,
            output_type=self.output_type,
        )

        return engine_config, input_config