compilation.py 31.6 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import enum
5
6
7
import hashlib
from collections import Counter
from dataclasses import asdict, field
8
from pathlib import Path
9
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
10

11
from pydantic import TypeAdapter, field_validator
12
13
14
15
16
17
18
19
from pydantic.dataclasses import dataclass

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    # Runtime placeholder: in this file the name only appears in string
    # annotations (e.g. ``init_backend(self, vllm_config: "VllmConfig")``).
    # NOTE(review): presumably not imported at runtime to avoid an import
    # cycle with ``vllm.config`` — confirm.
    VllmConfig = object

logger = init_logger(__name__)


class CompilationLevel:
    """Integer constants accepted by `CompilationConfig.level`."""

    NO_COMPILATION = 0  # no compilation at all
    DYNAMO_AS_IS = 1  # dynamo as is
    DYNAMO_ONCE = 2  # dynamo once
    PIECEWISE = 3  # piecewise compilation


class CUDAGraphMode(enum.Enum):
    """Constants for the cudagraph mode in CompilationConfig.

    Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
    treated as concrete runtime mode for cudagraph runtime dispatching.

    The composite members `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` store a
    `(decode_mode, mixed_mode)` pair of concrete member values; the helper
    methods below unpack that pair.
    """

    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def decode_mode(self) -> "CUDAGraphMode":
        """Concrete mode used for decode batches (``self`` if not composite)."""
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        """Concrete mode used for mixed batches (``self`` if not composite)."""
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def requires_piecewise_compilation(self) -> bool:
        """True if either the decode or the mixed routine is PIECEWISE."""
        return (
            self.decode_mode() == CUDAGraphMode.PIECEWISE
            or self.mixed_mode() == CUDAGraphMode.PIECEWISE
        )

    def max_cudagraph_mode(self) -> "CUDAGraphMode":
        """Concrete member with the largest value among this mode's routines."""
        return CUDAGraphMode(max(self.value)) if self.separate_routine() else self

    def has_full_cudagraphs(self) -> bool:
        """True if any routine of this mode captures FULL cudagraphs."""
        # FULL has the largest concrete value, so max() == FULL iff any
        # routine is FULL.
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

    def has_piecewise_cudagraphs(self) -> bool:
        """True if any routine of this mode captures PIECEWISE cudagraphs."""
        return self.requires_piecewise_compilation()

    def separate_routine(self) -> bool:
        """True for composite modes with distinct decode / mixed routines."""
        return isinstance(self.value, tuple)

    def valid_runtime_modes(self) -> bool:
        """True if this is a concrete mode usable for runtime dispatching."""
        return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]

    def __str__(self) -> str:
        return self.name


@config
@dataclass
class PassConfig:
    """Configuration for custom Inductor passes.

    This is separate from general `CompilationConfig` so that inductor passes
    don't all have access to full configuration - that would create a cycle as
    the `PassManager` is set as a property of config."""

    enable_fusion: bool = False
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
    enable_attn_fusion: bool = False
    """Whether to enable the custom attention+quant fusion pass."""
    enable_noop: bool = False
    """Whether to enable the custom no-op elimination pass."""
    enable_sequence_parallelism: bool = False
    """Whether to enable sequence parallelism."""
    enable_async_tp: bool = False
    """Whether to enable async TP."""
    enable_fi_allreduce_fusion: bool = False
    """Whether to enable flashinfer allreduce fusion."""
    fi_allreduce_fusion_max_token_num: int = 16384
    """Max number of tokens to be used in flashinfer allreduce fusion."""

    # TODO(luka) better pass enabling system.

    def uuid(self):
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.
        """
        # Every dataclass field currently feeds the hash.
        return InductorPass.hash_dict(asdict(self))

    def __post_init__(self) -> None:
        """Warn when fusion passes are enabled without no-op elimination.

        The fusion patterns may not match unless reshapes were eliminated by
        the no-op pass (``enable_noop``), so each enabled fusion pass logs a
        one-time warning when ``enable_noop`` is off.
        """
        if self.enable_noop:
            return
        if self.enable_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "RMSNorm/SiluMul + quant (fp8) fusion might not work"
            )
        if self.enable_attn_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "Attention + quant (fp8) fusion might not work"
            )


@config
@dataclass
class CompilationConfig:
    """Configuration for compilation. It has three parts:

    - Top-level Compilation control:
        - [`level`][vllm.config.CompilationConfig.level]
        - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
        - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
        - [`backend`][vllm.config.CompilationConfig.backend]
        - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
        - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
    - CudaGraph capture:
        - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
        - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
        - [`cudagraph_capture_sizes`]
        [vllm.config.CompilationConfig.cudagraph_capture_sizes]
        - [`cudagraph_num_of_warmups`]
        [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
        - [`cudagraph_copy_inputs`]
        [vllm.config.CompilationConfig.cudagraph_copy_inputs]
        - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
    - Inductor compilation:
        - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
        - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
        - [`inductor_compile_config`]
        [vllm.config.CompilationConfig.inductor_compile_config]
        - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
        - [`use_inductor_graph_partition`]
        [vllm.config.CompilationConfig.use_inductor_graph_partition]
        - custom inductor passes

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used
        for the same size. We need to capture all the sizes we want to use.
    - inductor: a graph compiled by inductor for a general shape can be used
        for different sizes. Inductor can also compile for specific sizes,
        where it can have more information to optimize the graph with fully
        static shapes. However, we find the general shape compilation is
        sufficient for most cases. It might be beneficial to compile for
        certain small batchsizes, where inductor is good at optimizing.
    """

    # Top-level Compilation control
    level: Optional[int] = None
    """The level of compilation:

    - None: If None, we will select the default compilation level.
      For V1 engine this is 3, for V0 engine this is 0.
    - 0: no compilation.
    - 1: dynamo as is.
    - 2: dynamo once.
    - 3: piecewise compilation."""
    debug_dump_path: Optional[Path] = None
    """The path to dump the debug information."""
    cache_dir: str = ""
    """The directory to store the compiled graph, to accelerate Inductor
    compilation. By default, it will use model-related information to generate
    a cache directory."""
    backend: str = ""
    """The backend for compilation. It needs to be a string:

    - "" (empty string): use the default backend.
    - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
    - "full.module.name": a qualified name which can be used to import the
      backend function.

    We use string to avoid serialization issues when using compilation in a
    distributed setting. When the compilation level is 1 or 2, the backend is
    used for the compilation directly (it sees the whole graph). When the
    compilation level is 3, the backend is used for the piecewise compilation
    (it sees a part of the graph)."""
    custom_ops: list[str] = field(default_factory=list)
    """Fine-grained control over which custom ops to enable/disable. Use 'all'
    to enable all, 'none' to disable all. Also specify a list of custom op
    names to enable (prefixed with a '+'), or disable (prefixed with a '-').
    Examples:

    - 'all,-op1' to enable all except op1
    - 'none,+op1,+op2' to enable only op1 and op2

    By default, all custom ops are enabled when running without Inductor and
    disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
    Inductor generates (fused) Triton kernels for disabled custom ops."""
    splitting_ops: Optional[list[str]] = None
    """A list of ops to split the full graph into subgraphs, used in piecewise
    compilation."""

    # Inductor capture
    use_inductor: bool = True
    """Whether to use inductor compilation:

    - False: inductor compilation is not used. graph runs in eager
        (custom_ops enabled by default).
    - True: inductor compilation is used (custom_ops disabled by default).
        One graph for symbolic shape and one graph per size in compile_sizes
        are compiled using configurations in inductor_compile_config.

    This setting is ignored if level<PIECEWISE."""
    compile_sizes: Optional[list[Union[int, str]]] = None
    """Sizes to compile for inductor. In addition
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""
    inductor_compile_config: dict = field(default_factory=dict)
    """Additional configurations for inductor.
    - Empty (the default): use default configurations."""
    inductor_passes: dict[str, str] = field(default_factory=dict)
    """Additional passes for inductor. It is a dictionary
    from pass name to pass function qualified name. We use function
    name because the config uses JSON format. If we pass the config
    from Python, functions can also be passed directly via Python object
    constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

    # CudaGraph compilation
    cudagraph_mode: Optional[CUDAGraphMode] = None
    """
    The mode of the cudagraph:

    - NONE, no cudagraph capture.
    - PIECEWISE.
    - FULL.
    - FULL_DECODE_ONLY.
    - FULL_AND_PIECEWISE. (v1 default)

    PIECEWISE mode builds piecewise cudagraphs only, keeping the
    cudagraph-incompatible ops (e.g. some attention ops) outside the
    cudagraph for general flexibility.

    FULL mode: Capture full cudagraph for all batches. Can be good for small
    models or workloads with small prompts; not supported by many backends.
    Generally for performance FULL_AND_PIECEWISE is better.

    FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
    Mixed prefill-decode batches are run without cudagraphs. Can be good for
    decode instances in a P/D setup where prefill is not as important so we
    can save some memory.

    FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
    piecewise cudagraph for prefill and mixed prefill-decode batches.
    This is the most performant mode for most models and is the default.

    Currently, the cudagraph mode is only used for the v1 engine.
    Note that the cudagraph logic is generally orthogonal to the
    compilation logic. While piecewise cudagraphs require piecewise
    compilation (level=PIECEWISE and non-empty splitting_ops), full
    cudagraphs are supported with and without compilation.

    Warning: This flag is new and subject to change; in addition, more
    modes may be added.
    """
    use_cudagraph: bool = True
    """Whether to use cudagraph inside compilation.
    - False: cudagraph inside compilation is not used.
    - True: cudagraph inside compilation is used. It requires
        that all input buffers have fixed addresses, and all
        splitting ops write their outputs to input buffers.
    In the vLLM V1 Engine, this flag only applies for
    CompilationLevel.PIECEWISE (aka -O3).
    Note that this is orthogonal to the cudagraph capture logic
    outside of compilation.
    Warning: This flag is deprecated and will be removed in the next major or
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
    instead.
    """
    cudagraph_num_of_warmups: int = 0
    """Number of warmup runs for cudagraph.
    It means the first several runs will be treated as warmup runs.
    Only after that, the execution will be recorded, and the recorded
    cudagraph will be used for subsequent runs."""
    cudagraph_capture_sizes: Optional[list[int]] = None
    """Sizes to capture cudagraph.
    - None (default): capture sizes are inferred from vllm config.
    - list[int]: capture sizes are specified as given."""
    cudagraph_copy_inputs: bool = False
    """Whether to copy input tensors for
    cudagraph. If the caller can guarantee that the same input buffers
    are always used, it can set this to False. Otherwise, it should
    set this to True, and the compiler will copy the input to an
    internally managed buffer. Default is False.
    Note that this flag is only effective when cudagraph_mode is PIECEWISE.
    """
    full_cuda_graph: Optional[bool] = False
    """whether to use a full cuda graph for the entire forward pass rather than
    splitting certain operations such as attention into subgraphs. Thus this
    flag cannot be used together with splitting_ops. This may provide
    performance benefits for smaller models.
    Warning: This flag is deprecated and will be removed in the next major or
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
    FULL_AND_PIECEWISE instead.
    """

    use_inductor_graph_partition: bool = False
    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
    This partition happens at inductor codegen time after all passes and fusions
    are finished. It generates a single `call` function which wraps
    cudagraph-safe ops into partition functions and leaves cudagraph-unsafe ops
    outside the partition functions. For a graph with N cudagraph-unsafe ops
    (e.g., Attention), there would be N+1 partitions. To mark an op as
    cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe,)` when
    registering the custom op.

    This config supports both full cudagraph and piecewise cudagraph without
    compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
    to each partition. For N+1 partitions, there would be N+1
    CUDAGraph wrapper instances.

    For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
    inductor `call` function in the model runner. The top-level full cudagraph
    capture ignores all partitioning.
    """

    pass_config: PassConfig = field(default_factory=PassConfig)
    """Custom inductor passes, see PassConfig for more details"""

    max_capture_size: int = field(default=None, init=False)  # type: ignore
    """not configurable, computed after init"""
    local_cache_dir: str = field(default=None, init=False)  # type: ignore
    """local cache dir for each rank"""
    bs_to_padded_graph_size: list[int] = field(
        default=None,  # type: ignore
        init=False,
    )
    """optimization:
    Intuitively, bs_to_padded_graph_size should be dict[int, int].
    since we know all keys are in a range [0, max_capture_size],
    we can optimize it to list[int] for better lookup performance."""

    # keep track of enabled and disabled custom ops
    enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
    """custom ops that are enabled"""
    disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
    """custom ops that are disabled"""
    traced_files: set[str] = field(default_factory=set, init=False)
    """files that are traced for compilation"""
    compilation_time: float = field(default=0.0, init=False)
    """time taken for compilation"""

    static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
    """Per-model forward context
    Map from layer name to layer objects that need to be accessed outside
    model code, e.g., Attention, FusedMOE when dp_size>1."""

    # Attention ops; used for piecewise cudagraphs
    _attention_ops: ClassVar[list[str]] = [
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.unified_mla_attention",
        "vllm.unified_mla_attention_with_output",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
        "vllm.short_conv",
        "vllm.linear_attention",
        "vllm.plamo2_mamba_mixer",
        "vllm.gdn_attention",
        "vllm.sparse_attn_indexer",
    ]

    def compute_hash(self) -> str:
        """
        Return a SHA-256 hex digest of the settings that shape the compiled
        computation graph (input ids/embeddings -> final hidden states).

        WARNING: when adding a field to this config that affects that graph,
        add it to ``factors`` below, otherwise cached compilations may be
        reused for a config that should have produced a different graph.
        """
        factors: list[Any] = [
            self.level,
            self.backend,
            self.custom_ops,
            self.splitting_ops,
            self.use_inductor,
            self.inductor_compile_config,
            self.inductor_passes,
            self.pass_config.uuid(),
        ]
        digest = hashlib.sha256(str(factors).encode())
        return digest.hexdigest()

    def __repr__(self) -> str:
        """Render the user-relevant part of the config as a string.

        Runtime bookkeeping fields are excluded, as is the
        ``post_grad_custom_post_pass`` entry of ``inductor_compile_config``.
        ``pass_config`` attributes equal to a default ``PassConfig()`` are
        hidden, and fields never explicitly set are dropped
        (``exclude_unset=True``).
        """
        exclude: dict[str, Any] = {
            "static_forward_context": True,
            "enabled_custom_ops": True,
            "disabled_custom_ops": True,
            "compilation_time": True,
            "bs_to_padded_graph_size": True,
            "traced_files": True,
            "inductor_compile_config": {
                "post_grad_custom_post_pass": True,
            },
        }

        # exclude default attr in pass_config
        pass_config_exclude = {
            attr: True
            for attr, default_val in vars(PassConfig()).items()
            if getattr(self.pass_config, attr) == default_val
        }
        if pass_config_exclude:
            exclude["pass_config"] = pass_config_exclude

        # Named ``dumped`` rather than ``config`` so the module-level
        # ``config`` decorator is not shadowed.
        dumped = TypeAdapter(CompilationConfig).dump_python(
            self, exclude=exclude, exclude_unset=True
        )
        return str(dumped)

    __str__ = __repr__

    @field_validator("cudagraph_mode", mode="before")
    @classmethod
    def validate_cudagraph_mode_before(cls, value: Any) -> Any:
        """
        Allow ``cudagraph_mode`` to be given by member name (case-insensitive).

        A string is upper-cased and looked up on ``CUDAGraphMode`` (an unknown
        name raises ``KeyError``); anything else is returned unchanged for the
        regular field validation that runs after this "before" validator.
        """
        if not isinstance(value, str):
            return value
        member_name = value.upper()
        return CUDAGraphMode[member_name]

    def __post_init__(self) -> None:
        """Validate user-supplied fields and normalize derived settings.

        In order:

        - ``custom_ops`` may contain at most one of ``"all"`` / ``"none"``.
        - On torch>=2.6, ``enable_auto_functionalized_v2`` defaults to False in
          ``inductor_compile_config`` unless explicitly set (see TODO below).
        - Each ``inductor_passes`` entry (a callable or a qualified name) is
          resolved and stored in ``inductor_compile_config`` under its name,
          wrapped in ``CallableInductorPass`` unless already an
          ``InductorPass``.
        - A dict ``pass_config`` is converted into ``PassConfig``.
        - Deprecated ``use_cudagraph=False`` / ``full_cuda_graph=True`` are
          migrated to ``cudagraph_mode`` NONE / FULL.
        - ``use_inductor_graph_partition`` requires torch>=2.9.0.dev.
        - Every ``custom_ops`` entry other than ``all``/``none`` must start
          with ``+`` or ``-``.

        Raises:
            ValueError: on conflicting deprecated and new cudagraph flags, on
                unsupported ``use_inductor_graph_partition``, or on malformed
                ``custom_ops`` entries.
        """
        count_none = self.custom_ops.count("none")
        count_all = self.custom_ops.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

        # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
        # 1. A bug in PyTorch, fixed in 2.7:
        #    https://github.com/pytorch/pytorch/issues/147924
        # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
        #    work with V2. Addressing this will take extra engineering effort
        #    and it is not yet a priority. RFC here:
        #    https://github.com/vllm-project/vllm/issues/14703

        if is_torch_equal_or_newer("2.6"):
            # Respect an explicit user setting; otherwise disable V2.
            self.inductor_compile_config.setdefault(
                "enable_auto_functionalized_v2", False
            )

        for name, pass_obj in self.inductor_passes.items():
            if isinstance(pass_obj, str):
                # Resolve "pkg.module.func" with importlib. The builtin
                # ``__import__("pkg.module")`` returns the top-level package
                # ``pkg``, which cannot find passes defined in nested modules.
                pass_obj = resolve_obj_by_qualname(pass_obj)
            else:
                assert callable(pass_obj), (
                    f"pass {name} should be callable or a qualified name"
                )
            self.inductor_compile_config[name] = (
                pass_obj
                if isinstance(pass_obj, InductorPass)
                else CallableInductorPass(pass_obj)
            )

        if isinstance(self.pass_config, dict):
            self.pass_config = PassConfig(**self.pass_config)

        # migrate the deprecated flags
        if not self.use_cudagraph:
            logger.warning(
                "use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
            )
            if (
                self.cudagraph_mode is not None
                and self.cudagraph_mode != CUDAGraphMode.NONE
            ):
                raise ValueError(
                    "use_cudagraph and cudagraph_mode are mutually"
                    " exclusive, prefer cudagraph_mode since "
                    "use_cudagraph is deprecated."
                )
            self.cudagraph_mode = CUDAGraphMode.NONE
        if self.full_cuda_graph:
            logger.warning(
                "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
            )
            if (
                self.cudagraph_mode is not None
                and not self.cudagraph_mode.has_full_cudagraphs()
            ):
                raise ValueError(
                    "full_cuda_graph and cudagraph_mode are "
                    "mutually exclusive, prefer cudagraph_mode "
                    "since full_cuda_graph is deprecated."
                )
            self.cudagraph_mode = CUDAGraphMode.FULL

        if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
            "2.9.0.dev"
        ):
            raise ValueError(
                "use_inductor_graph_partition is only "
                "supported with torch>=2.9.0.dev. Set "
                "use_inductor_graph_partition=False instead."
            )

        for op in self.custom_ops:
            if op in ("all", "none"):
                continue
            # startswith() also rejects an empty entry with the intended
            # ValueError instead of an IndexError from op[0].
            if not op.startswith(("+", "-")):
                raise ValueError(
                    f"Invalid syntax '{op}' for custom op, "
                    "must be 'all', 'none', '+op' or '-op' "
                    "(where 'op' is the registered op name)"
                )

    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
        """Resolve the backend to hand to ``torch.compile`` for this level.

        - DYNAMO_AS_IS / DYNAMO_ONCE: ``"eager"`` when ``backend`` is empty,
          the name itself when it is listed by torch dynamo's backend
          registry, otherwise the object imported from the qualified name in
          ``backend``.
        - PIECEWISE: a ``VllmBackend`` built from ``vllm_config``.

        Raises:
            ValueError: if ``level`` is NO_COMPILATION.
        """
        if self.level == CompilationLevel.NO_COMPILATION:
            raise ValueError("No compilation level is set.")

        from torch._dynamo.backends.registry import list_backends

        torch_backends = list_backends(exclude_tags=tuple())
        if self.level in (CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE):
            if self.backend == "":
                return "eager"
            if self.backend in torch_backends:
                return self.backend
            return resolve_obj_by_qualname(self.backend)

        # TODO: pass user-specified backend to piecewise compilation
        # merge with the config use_inductor
        assert self.level == CompilationLevel.PIECEWISE

        from vllm.compilation.backends import VllmBackend

        return VllmBackend(vllm_config)

    def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
        """To complete the initialization of config,
        we need to know the cudagraph sizes.

        Sets ``cudagraph_capture_sizes`` (sizes from the config win over the
        ``cudagraph_capture_sizes`` argument and are de-duplicated), expands
        ``compile_sizes`` (the string ``"cudagraph_capture_sizes"`` expands to
        those sizes), then derives ``max_capture_size`` and the
        ``bs_to_padded_graph_size`` lookup table.
        """
        if self.cudagraph_capture_sizes is None:
            # Copy: the in-place sort below must not mutate the caller's list.
            self.cudagraph_capture_sizes = list(cudagraph_capture_sizes)
        else:
            # de-duplicate the sizes provided by the config
            dedup_sizes = list(dict.fromkeys(self.cudagraph_capture_sizes))
            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
                logger.info(
                    (
                        "cudagraph sizes specified by model runner"
                        " %s is overridden by config %s"
                    ),
                    cudagraph_capture_sizes,
                    dedup_sizes,
                )
            self.cudagraph_capture_sizes = dedup_sizes

        computed_compile_sizes: list[int] = []
        if self.compile_sizes is not None:
            # de-duplicate the sizes provided by the config; dict.fromkeys
            # keeps the user's order, whereas iterating a set containing a
            # str varies from process to process (hash randomization).
            self.compile_sizes = list(dict.fromkeys(self.compile_sizes))
            for x in self.compile_sizes:
                if isinstance(x, str):
                    assert x == "cudagraph_capture_sizes", (
                        "Unrecognized size type in compile_sizes, "
                        f"expect 'cudagraph_capture_sizes', got {x}"
                    )
                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
                else:
                    assert isinstance(x, int)
                    computed_compile_sizes.append(x)
        self.compile_sizes = computed_compile_sizes  # type: ignore

        # sort to make sure cudagraph capture sizes are in descending order
        self.cudagraph_capture_sizes.sort(reverse=True)
        self.max_capture_size = (
            self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
        )

        # pre-compute the mapping from batch size to padded graph size:
        # entry bs holds the smallest capture size >= bs (entry 0 holds 0),
        # assuming the capture sizes are positive.
        self.bs_to_padded_graph_size = [0] * (self.max_capture_size + 1)
        for end, start in zip(
            self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
        ):
            for bs in range(start, end):
                if bs == start:
                    self.bs_to_padded_graph_size[bs] = start
                else:
                    self.bs_to_padded_graph_size[bs] = end
        self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size

    def set_splitting_ops_for_v1(self):
        """Choose ``splitting_ops`` for piecewise compilation.

        Delegates to the inductor-graph-partition or attention-fusion variants
        when those features are enabled. Otherwise ``None`` becomes a copy of
        the attention ops. An explicitly empty list downgrades a cudagraph
        mode that needs piecewise graphs (PIECEWISE -> NONE,
        FULL_AND_PIECEWISE -> FULL).
        """
        # NOTE: this function needs to be called only when level is
        # CompilationLevel.PIECEWISE
        assert self.level == CompilationLevel.PIECEWISE, (
            "set_splitting_ops_for_v1 should only be called when "
            "level is CompilationLevel.PIECEWISE"
        )

        if self.use_inductor_graph_partition:
            self.set_splitting_ops_for_inductor_graph_partition()
            return

        if self.pass_config.enable_attn_fusion:
            # here use_inductor_graph_partition is False
            self.set_splitting_ops_for_attn_fusion()
            return

        if self.splitting_ops is None:
            # NOTE: When using full cudagraph, instead of setting an empty
            # list and capture the full cudagraph inside the flattened fx
            # graph, we keep the piecewise fx graph structure but capture
            # the full cudagraph outside the fx graph. This reduces some
            # cpu overhead when the runtime batch_size is not cudagraph
            # captured. see https://github.com/vllm-project/vllm/pull/20059
            # for details. Make a copy to avoid mutating the class-level
            # list via reference.
            self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once("Using piecewise compilation with empty splitting_ops")
            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops does not "
                    "contain piecewise cudagraph. Setting cudagraph_"
                    "mode to NONE. Hint: If you are using attention backends "
                    "that support cudagraph, consider manually setting "
                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
                    "full cudagraphs."
                )
                self.cudagraph_mode = CUDAGraphMode.NONE
            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops does not "
                    "contain piecewise cudagraph. Setting cudagraph_mode "
                    "to FULL."
                )
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []

    def set_splitting_ops_for_inductor_graph_partition(self):
        """Clear ``splitting_ops`` because Inductor graph partition is enabled.

        User-provided splitting ops are ignored in this mode (a one-time
        warning is logged when any were given).
        """
        assert self.use_inductor_graph_partition
        use_inductor_graph_partition_msg = (
            "When use_inductor_graph_partition=True, splitting_ops "
            "are ignored and set to an empty list. Instead, "
            '"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
            "used to annotate custom ops for graph partition."
        )
        # Truthiness covers both None and an empty list.
        if self.splitting_ops:
            logger.warning_once(use_inductor_graph_partition_msg)
        self.splitting_ops = []

    def set_splitting_ops_for_attn_fusion(self):
        """Configure splitting for the attention+quant fusion pass.

        Attention must not be split out of the graph for the fusion to apply.
        Unset ``splitting_ops`` therefore become an empty list, and a
        cudagraph mode with piecewise graphs is switched to FULL. Explicit
        ``splitting_ops`` must not contain the attention ops.
        """
        assert self.pass_config.enable_attn_fusion
        if self.splitting_ops is None:
            self.splitting_ops = []
            # NOTE(review): assumes cudagraph_mode was already resolved (not
            # None) by the caller; None would raise AttributeError here.
            if self.cudagraph_mode.has_piecewise_cudagraphs():
                logger.warning_once(
                    "enable_attn_fusion is incompatible with piecewise "
                    "cudagraph when use_inductor_graph_partition is off. "
                    "In this case, splitting_ops will be set to empty "
                    "list, and cudagraph_mode will be set to FULL. "
                    "Please ensure you are using attention backends that "
                    "support cudagraph or set cudagraph_mode to NONE "
                    "explicitly if encountering any problems."
                )
                self.cudagraph_mode = CUDAGraphMode.FULL

        assert not self.splitting_ops_contain_attention(), (
            "attention ops should not be in splitting_ops "
            "when enable_attn_fusion is True"
        )

    def splitting_ops_contain_attention(self) -> bool:
        """Return True iff ``splitting_ops`` is set and lists every op in
        ``_attention_ops``."""
        if self.splitting_ops is None:
            return False
        return all(op in self.splitting_ops for op in self._attention_ops)

    def is_attention_compiled_piecewise(self) -> bool:
        """Return True if attention is split out of the compiled graph.

        Two routes qualify:

        - FX-graph splitting: level PIECEWISE with all attention ops in
          ``splitting_ops``.
        - Inductor graph partition: Inductor is used (level PIECEWISE with
          ``use_inductor``, or level >= DYNAMO_AS_IS with backend
          ``"inductor"``), ``use_inductor_graph_partition`` is on, and the
          attention ops are not already in ``splitting_ops``.
        """
        contains_attention = self.splitting_ops_contain_attention()
        use_fx_graph_piecewise_compilation = (
            self.level == CompilationLevel.PIECEWISE and contains_attention
        )

        inductor_used = (
            self.level == CompilationLevel.PIECEWISE and self.use_inductor
        ) or (
            # ``level`` may still be None (default not yet selected); comparing
            # None with >= would raise TypeError.
            self.level is not None
            and self.level >= CompilationLevel.DYNAMO_AS_IS
            and self.backend == "inductor"
        )
        use_inductor_piecewise_compilation = (
            inductor_used
            and self.use_inductor_graph_partition
            and not contains_attention
        )

        return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation

    def custom_op_log_check(self):
        """
        This method logs the enabled/disabled custom ops and checks that the
        passed custom_ops field only contains relevant ops.
        It is called at the end of set_current_vllm_config,
        after the custom ops have been instantiated.

        For every ``+op`` / ``-op`` entry whose op is neither enabled nor
        disabled in the model, a one-time warning says whether the op is
        unknown to ``CustomOp.op_registry`` or merely absent from this model.
        """
        if not self.enabled_custom_ops and not self.disabled_custom_ops:
            logger.debug("No custom ops found in model.")
            return

        logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
        logger.debug("disabled custom ops: %s", self.disabled_custom_ops)

        all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops
        for op in self.custom_ops:
            if op in {"all", "none"}:
                continue

            assert op[0] in {"+", "-"}, (
                "Invalid custom op syntax (should be checked during init)"
            )

            # check if op name exists in model
            op_name = op[1:]
            if op_name not in all_ops_in_model:
                from vllm.model_executor.custom_op import CustomOp

                # Does op exist at all or is it just not present in this model?
                # Note: Only imported op classes appear in the registry.
                missing_str = (
                    "doesn't exist (or wasn't imported/registered)"
                    if op_name not in CustomOp.op_registry
                    else "not present in model"
                )

                enable_str = "enabling" if op[0] == "+" else "disabling"
                logger.warning_once(
                    "Op '%s' %s, %s with '%s' has no effect",
                    op_name,
                    missing_str,
                    enable_str,
                    op,
                )