# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
import hashlib
from collections import Counter
from dataclasses import asdict, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union

from pydantic import TypeAdapter, field_validator
from pydantic.dataclasses import dataclass

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    # Placeholder so the name exists at runtime without importing
    # vllm.config (annotations that use it are written as strings).
    VllmConfig = object

logger = init_logger(__name__)


class CompilationLevel:
    """Integer constants naming the levels of the compilation process."""

    NO_COMPILATION = 0  # no compilation at all
    DYNAMO_AS_IS = 1  # dynamo as is
    DYNAMO_ONCE = 2  # dynamo once
    PIECEWISE = 3  # piecewise compilation


class CUDAGraphMode(enum.Enum):
    """Constants for the cudagraph mode in CompilationConfig.

    Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
    treated as concrete runtime mode for cudagraph runtime dispatching.

    Combined modes have a tuple value `(decode_mode, mixed_mode)` built from
    the single-mode values above them (e.g. FULL_DECODE_ONLY == (2, 0)).
    """

    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def decode_mode(self) -> "CUDAGraphMode":
        """Mode for decode batches: the first tuple element of a combined
        mode, or the mode itself for a single mode."""
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        """Mode for mixed prefill-decode batches: the second tuple element of
        a combined mode, or the mode itself for a single mode."""
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        """Return whether the single mode `mode` is part of this mode.

        `mode` must not itself be a combined mode (asserted).
        """
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        """Return whether this mode includes PIECEWISE cudagraphs."""
        return self.has_mode(CUDAGraphMode.PIECEWISE)

    def max_cudagraph_mode(self) -> "CUDAGraphMode":
        """Return the highest-valued single mode contained in this mode."""
        return CUDAGraphMode(max(self.value)) if self.separate_routine() else self

    def has_full_cudagraphs(self) -> bool:
        """Return whether this mode captures FULL cudagraphs for any batch."""
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

    def has_piecewise_cudagraphs(self) -> bool:
        """Alias of `requires_piecewise_compilation`."""
        return self.requires_piecewise_compilation()

    def separate_routine(self) -> bool:
        """Return whether this is a combined mode (value is a tuple)."""
        return isinstance(self.value, tuple)

    def valid_runtime_modes(self) -> bool:
        """Return whether this is a concrete runtime mode: NONE, PIECEWISE
        or FULL."""
        return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]

    def __str__(self) -> str:
        return self.name


@config
@dataclass
class PassConfig:
    """Configuration for custom Inductor passes.

    This is separate from general `CompilationConfig` so that inductor passes
    don't all have access to full configuration - that would create a cycle as
    the `PassManager` is set as a property of config."""

    enable_fusion: bool = False
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
    enable_attn_fusion: bool = False
    """Whether to enable the custom attention+quant fusion pass."""
    enable_noop: bool = False
    """Whether to enable the custom no-op elimination pass."""
    enable_sequence_parallelism: bool = False
    """Whether to enable sequence parallelism."""
    enable_async_tp: bool = False
    """Whether to enable async TP."""
    enable_fi_allreduce_fusion: bool = False
    """Whether to enable flashinfer allreduce fusion."""
    fi_allreduce_fusion_max_token_num: int = 16384
    """Max number of tokens to be used in flashinfer allreduce fusion."""

    # TODO(luka) better pass enabling system.

    def uuid(self):
        """
        Produces a hash unique to the pass configuration.

        All dataclass fields (via `asdict`) are passed to
        `InductorPass.hash_dict`. Any new fields that affect compilation should
        be added to the hash. Any future fields that don't affect compilation
        should be excluded.
        """
        return InductorPass.hash_dict(asdict(self))

    def __post_init__(self) -> None:
        """Warn (once) about fusion passes enabled without no-op elimination."""
        # Per the warning text, the fusion passes might not work when reshape
        # (no-op) elimination is disabled; this only warns, it does not fail.
        if not self.enable_noop:
            if self.enable_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "RMSNorm/SiluMul + quant (fp8) fusion might not work"
                )
            if self.enable_attn_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "Attention + quant (fp8) fusion might not work"
                )


@config
@dataclass
class CompilationConfig:
    """Configuration for compilation. It has three parts:

    - Top-level Compilation control:
        - [`level`][vllm.config.CompilationConfig.level]
        - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
        - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
        - [`backend`][vllm.config.CompilationConfig.backend]
        - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
        - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
    - CudaGraph capture:
        - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
        - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
        - [`cudagraph_capture_sizes`]
        [vllm.config.CompilationConfig.cudagraph_capture_sizes]
        - [`cudagraph_num_of_warmups`]
        [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
        - [`cudagraph_copy_inputs`]
        [vllm.config.CompilationConfig.cudagraph_copy_inputs]
        - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
    - Inductor compilation:
        - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
        - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
        - [`inductor_compile_config`]
        [vllm.config.CompilationConfig.inductor_compile_config]
        - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
        - custom inductor passes

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used
        for the same size. We need to capture all the sizes we want to use.
    - inductor: a graph compiled by inductor for a general shape can be used
        for different sizes. Inductor can also compile for specific sizes,
        where it can have more information to optimize the graph with fully
        static shapes. However, we find the general shape compilation is
        sufficient for most cases. It might be beneficial to compile for
        certain small batchsizes, where inductor is good at optimizing.
    """

    # Top-level Compilation control
    level: Optional[int] = None
    """The level of compilation:

    - None: If None, we will select the default compilation level.
      For V1 engine this is 3, for V0 engine this is 0.
    - 0: no compilation.
    - 1: dynamo as is.
    - 2: dynamo once.
    - 3: piecewise compilation."""
    debug_dump_path: Optional[Path] = None
    """The path to dump the debug information."""
    cache_dir: str = ""
    """The directory to store the compiled graph, to accelerate Inductor
    compilation. By default, it will use model-related information to generate
    a cache directory."""
    backend: str = ""
    """The backend for compilation. It needs to be a string:

    - "" (empty string): use the default backend.
    - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
    - "full.module.name": a qualified name which can be used to import the
    backend function.

    We use string to avoid serialization issues when using compilation in a
    distributed setting. When the compilation level is 1 or 2, the backend is
    used for the compilation directly (it sees the whole graph). When the
    compilation level is 3, the backend is used for the piecewise compilation
    (it sees a part of the graph)."""
    custom_ops: list[str] = field(default_factory=list)
    """Fine-grained control over which custom ops to enable/disable. Use 'all'
    to enable all, 'none' to disable all. Also specify a list of custom op
    names to enable (prefixed with a '+'), or disable (prefixed with a '-').
    Examples:

    - 'all,-op1' to enable all except op1
    - 'none,+op1,+op2' to enable only op1 and op2

    By default, all custom ops are enabled when running without Inductor and
    disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
    Inductor generates (fused) Triton kernels for disabled custom ops."""
    splitting_ops: Optional[list[str]] = None
    """A list of ops to exclude from cudagraphs, used in piecewise compilation.

    The behavior depends on use_inductor_graph_partition:

    - When use_inductor_graph_partition=False (default):
        These ops are used for Dynamo FX-level graph splitting. The graph is
        split at these ops before Inductor compilation, creating separate
        subgraphs for cudagraph capture.

    - When use_inductor_graph_partition=True:
        These ops are used to register Inductor partition rules. The graph
        partitioning happens at Inductor codegen time after all passes and
        fusions are finished, allowing compilation and custom passes to operate
        on the full graph while still excluding these ops from cudagraphs.

    If None, defaults to attention ops for piecewise cudagraphs.
    If empty list [], no ops are excluded (suitable for full cudagraphs)."""

    # Inductor capture
    use_inductor: bool = True
    """Whether to use inductor compilation:

    - False: inductor compilation is not used. graph runs in eager
        (custom_ops enabled by default).
    - True: inductor compilation is used (custom_ops disabled by default).
        One graph for symbolic shape and one graph per size in compile_sizes
        are compiled using configurations in inductor_compile_config.

    This setting is ignored if level<PIECEWISE."""
    compile_sizes: Optional[list[Union[int, str]]] = None
    """Sizes to compile for inductor. In addition
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""
    inductor_compile_config: dict = field(default_factory=dict)
    """Additional configurations for inductor.
    - Empty (default): use default configurations."""
    inductor_passes: dict[str, str] = field(default_factory=dict)
    """Additional passes for inductor. It is a dictionary
    from pass name to pass function qualified name. We use function
    name because the config uses JSON format. If we pass the config
    from Python, functions can also be passed directly via Python object
    constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

    # CudaGraph compilation
    cudagraph_mode: Optional[CUDAGraphMode] = None
    """
    The mode of the cudagraph:

    - NONE, no cudagraph capture.
    - PIECEWISE.
    - FULL.
    - FULL_DECODE_ONLY.
    - FULL_AND_PIECEWISE. (v1 default)

    PIECEWISE mode builds piecewise cudagraph only, keeping the cudagraph
    incompatible ops (e.g. some attention ops) outside the cudagraph
    for general flexibility.

    FULL mode: Capture full cudagraph for all batches. Can be good for small
    models or workloads with small prompts; not supported by many backends.
    Generally for performance FULL_AND_PIECEWISE is better.

    FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
    Mixed prefill-decode batches are run without cudagraphs. Can be good for
    decode instances in a P/D setup where prefill is not as important so we
    can save some memory.

    FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
    piecewise cudagraph for prefill and mixed prefill-decode batches.
    This is the most performant mode for most models and is the default.

    Currently, the cudagraph mode is only used for the v1 engine.
    Note that the cudagraph logic is generally orthogonal to the
    compilation logic. While piecewise cudagraphs require piecewise
    compilation (level=PIECEWISE and non-empty splitting_ops), full
    cudagraphs are supported with and without compilation.

    Warning: This flag is new and subject to change; in addition, more
    modes may be added.
    """
    use_cudagraph: bool = True
    """Whether to use cudagraph inside compilation.
    - False: cudagraph inside compilation is not used.
    - True: cudagraph inside compilation is used. It requires
        that all input buffers have fixed addresses, and all
        splitting ops write their outputs to input buffers.
    In the vLLM V1 Engine, this flag only applies for
    CompilationLevel.PIECEWISE (aka -O3).
    Note that this is orthogonal to the cudagraph capture logic
    outside of compilation.
    Warning: This flag is deprecated and will be removed in the next major or
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
    instead.
    """
    cudagraph_num_of_warmups: int = 0
    """Number of warmup runs for cudagraph.
    It means the first several runs will be treated as warmup runs.
    Only after that, the execution will be recorded, and the recorded
    cudagraph will be used for subsequent runs."""
    cudagraph_capture_sizes: Optional[list[int]] = None
    """Sizes to capture cudagraph.
    - None (default): capture sizes are inferred from vllm config.
    - list[int]: capture sizes are specified as given."""
    cudagraph_copy_inputs: bool = False
    """Whether to copy input tensors for
    cudagraph. If the caller can guarantee that the same input buffers
    are always used, it can set this to False. Otherwise, it should
    set this to True, and the compiler will copy the input to an
    internally managed buffer. Default is False.
    Note that this flag is only effective when cudagraph_mode is PIECEWISE.
    """
    full_cuda_graph: Optional[bool] = False
    """Whether to use a full cuda graph for the entire forward pass rather than
    splitting certain operations such as attention into subgraphs. Thus this
    flag cannot be used together with splitting_ops. This may provide
    performance benefits for smaller models.
    Warning: This flag is deprecated and will be removed in the next major or
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
    FULL_AND_PIECEWISE instead.
    """

    use_inductor_graph_partition: bool = False
    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
    This partition happens at inductor codegen time after all passes and fusions
    are finished. It generates a single `call` function which wraps
    cudagraph-safe ops into partition functions and leaves cudagraph-unsafe ops
    outside the partition functions. For a graph with N cudagraph-unsafe ops
    (e.g., Attention), there would be N+1 partitions. To mark an op as
    cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe,)` when
    registering the custom op.

    This config supports both full cudagraph and piecewise cudagraph without
    compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
    to each partition. For N+1 partitions, there would be N+1
    CUDAGraph wrapper instances.

    For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
    inductor `call` function in the model runner. The top-level full cudagraph
    capture ignores all partitioning.
    """

    pass_config: PassConfig = field(default_factory=PassConfig)
    """Custom inductor passes, see PassConfig for more details"""

    max_capture_size: int = field(default=None, init=False)  # type: ignore
    """not configurable, computed after init"""
    local_cache_dir: str = field(default=None, init=False)  # type: ignore
    """local cache dir for each rank"""
    bs_to_padded_graph_size: list[int] = field(
        default=None,  # type: ignore
        init=False,
    )
    """optimization:
    Intuitively, bs_to_padded_graph_size should be dict[int, int].
    since we know all keys are in a range [0, max_capture_size],
    we can optimize it to list[int] for better lookup performance."""

    # keep track of enabled and disabled custom ops
    enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
    """custom ops that are enabled"""
    disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
    """custom ops that are disabled"""
    traced_files: set[str] = field(default_factory=set, init=False)
    """files that are traced for compilation"""
    compilation_time: float = field(default=0.0, init=False)
    """time taken for compilation"""

    static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
    """Per-model forward context
    Map from layer name to layer objects that need to be accessed outside
    model code, e.g., Attention, FusedMOE when dp_size>1."""

    # Attention ops; used for piecewise cudagraphs
    # Use PyTorch operator format: "namespace::name"
    _attention_ops: ClassVar[list[str]] = [
        "vllm::unified_attention",
        "vllm::unified_attention_with_output",
        "vllm::unified_mla_attention",
        "vllm::unified_mla_attention_with_output",
        "vllm::mamba_mixer2",
        "vllm::mamba_mixer",
        "vllm::short_conv",
        "vllm::linear_attention",
        "vllm::plamo2_mamba_mixer",
        "vllm::gdn_attention",
        "vllm::sparse_attn_indexer",
    ]

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Return a SHA-256 hex digest over the configs that affect the
        structure of the computation graph from input ids/embeddings to the
        final hidden states (nothing before the inputs or after the final
        hidden states is covered).
        """
        # Order matters: the digest is taken over str() of this list.
        factors: list[Any] = [
            self.level,
            self.backend,
            self.custom_ops,
            self.splitting_ops,
            self.use_inductor,
            self.inductor_compile_config,
            self.inductor_passes,
            self.pass_config.uuid(),
        ]
        digest = hashlib.sha256(str(factors).encode())
        return digest.hexdigest()

    def __repr__(self) -> str:
        """Return the config as a dict string, omitting runtime-only state.

        Excluded: the per-run fields (forward context, custom-op counters,
        compilation time, padded-size table, traced files), the
        `post_grad_custom_post_pass` entry of `inductor_compile_config`, any
        `pass_config` attribute still equal to its `PassConfig()` default, and
        (via `exclude_unset=True`) fields that were never explicitly set.
        """
        exclude: dict[str, Any] = {
            "static_forward_context": True,
            "enabled_custom_ops": True,
            "disabled_custom_ops": True,
            "compilation_time": True,
            "bs_to_padded_graph_size": True,
            "traced_files": True,
            "inductor_compile_config": {
                "post_grad_custom_post_pass": True,
            },
        }

        # exclude default attr in pass_config
        pass_config_exclude = {
            attr: True
            for attr, default_val in vars(PassConfig()).items()
            if getattr(self.pass_config, attr) == default_val
        }
        if pass_config_exclude:
            exclude["pass_config"] = pass_config_exclude

        # Named `dumped` so the module-level `config` decorator is not shadowed.
        dumped = TypeAdapter(CompilationConfig).dump_python(
            self, exclude=exclude, exclude_unset=True
        )
        return str(dumped)

    __str__ = __repr__

    @field_validator("cudagraph_mode", mode="before")
    @classmethod
    def validate_cudagraph_mode_before(cls, value: Any) -> Any:
        """
        enable parse the `cudagraph_mode` enum type from string

        A string is matched case-insensitively against the `CUDAGraphMode`
        member names; any other value is returned unchanged.

        Raises:
            ValueError: if a string does not name a `CUDAGraphMode` member.
                Pydantic only turns ValueError/AssertionError raised by
                validators into a validation error, so a bare KeyError
                is not raised here.
        """
        if isinstance(value, str):
            try:
                return CUDAGraphMode[value.upper()]
            except KeyError:
                valid = ", ".join(mode.name for mode in CUDAGraphMode)
                raise ValueError(
                    f"Invalid cudagraph_mode {value!r}; expected one of: {valid}"
                ) from None
        return value

    def __post_init__(self) -> None:
        """Validate options and normalize derived fields after construction.

        - Disables auto-functionalization V2 by default on torch >= 2.6.
        - Wraps `inductor_passes` entries (callables or qualified names) as
          inductor passes and stores them in `inductor_compile_config`.
        - Converts a dict `pass_config` into a `PassConfig`.
        - Migrates the deprecated `use_cudagraph` / `full_cuda_graph` flags
          into `cudagraph_mode`.

        Raises:
            ValueError: if a deprecated cudagraph flag conflicts with
                `cudagraph_mode`, if `use_inductor_graph_partition` is set on
                torch < 2.9.0.dev, or if a `custom_ops` entry is malformed.
        """
        count_none = self.custom_ops.count("none")
        count_all = self.custom_ops.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

        # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
        # 1. A bug in PyTorch, fixed in 2.7:
        #    https://github.com/pytorch/pytorch/issues/147924
        # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
        #    work with V2. Addressing this will take extra engineering effort
        #    and it is not yet a priority. RFC here:
        #    https://github.com/vllm-project/vllm/issues/14703

        if is_torch_equal_or_newer("2.6"):
            KEY = "enable_auto_functionalized_v2"
            if KEY not in self.inductor_compile_config:
                self.inductor_compile_config[KEY] = False

        for k, v in self.inductor_passes.items():
            if not isinstance(v, str):
                assert callable(v), f"pass {k} should be callable or a qualified name"
                self.inductor_compile_config[k] = (
                    v if isinstance(v, InductorPass) else CallableInductorPass(v)
                )
                continue

            # Resolve the function from its qualified name. `__import__("a.b")`
            # returns the top-level package `a`, so looking the name up in its
            # `__dict__` fails for nested modules; import the defining module.
            func = resolve_obj_by_qualname(v)
            self.inductor_compile_config[k] = (
                func if isinstance(func, InductorPass) else CallableInductorPass(func)
            )

        if isinstance(self.pass_config, dict):
            self.pass_config = PassConfig(**self.pass_config)

        # migrate the deprecated flags
        if not self.use_cudagraph:
            logger.warning(
                "use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
            )
            if (
                self.cudagraph_mode is not None
                and self.cudagraph_mode != CUDAGraphMode.NONE
            ):
                raise ValueError(
                    "use_cudagraph and cudagraph_mode are mutually"
                    " exclusive, prefer cudagraph_mode since "
                    "use_cudagraph is deprecated."
                )
            self.cudagraph_mode = CUDAGraphMode.NONE
        if self.full_cuda_graph:
            logger.warning(
                "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
            )
            if (
                self.cudagraph_mode is not None
                and not self.cudagraph_mode.has_full_cudagraphs()
            ):
                raise ValueError(
                    "full_cuda_graph and cudagraph_mode are "
                    "mutually exclusive, prefer cudagraph_mode "
                    "since full_cuda_graph is deprecated."
                )
            self.cudagraph_mode = CUDAGraphMode.FULL

        if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
            "2.9.0.dev"
        ):
            raise ValueError(
                "use_inductor_graph_partition is only "
                "supported with torch>=2.9.0.dev. Set "
                "use_inductor_graph_partition=False instead."
            )

        for op in self.custom_ops:
            # `op[:1]` rather than `op[0]` so an empty entry is reported with
            # the syntax error below instead of crashing with IndexError.
            if op[:1] not in {"+", "-"} and op not in {"all", "none"}:
                raise ValueError(
                    f"Invalid syntax '{op}' for custom op, "
                    "must be 'all', 'none', '+op' or '-op' "
                    "(where 'op' is the registered op name)"
                )

    def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
        """Return the compilation backend for the configured level.

        - DYNAMO_AS_IS / DYNAMO_ONCE: "eager" if `backend` is empty; the
          backend name itself if it is in torch dynamo's backend registry;
          otherwise the object resolved from `backend` as a qualified name.
        - PIECEWISE: a `VllmBackend` constructed from `vllm_config`.

        Raises:
            ValueError: if the level is NO_COMPILATION.
        """
        if self.level == CompilationLevel.NO_COMPILATION:
            raise ValueError("No compilation level is set.")

        from torch._dynamo.backends.registry import list_backends

        torch_backends = list_backends(exclude_tags=tuple())
        if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
            if self.backend == "":
                return "eager"
            if self.backend in torch_backends:
                return self.backend
            return resolve_obj_by_qualname(self.backend)

        # TODO: pass user-specified backend to piecewise compilation
        # merge with the config use_inductor
        assert self.level == CompilationLevel.PIECEWISE

        from vllm.compilation.backends import VllmBackend

        return VllmBackend(vllm_config)

    def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
        """To complete the initialization of config,
        we need to know the cudagraph sizes.

        Args:
            cudagraph_capture_sizes: sizes proposed by the caller; used only
                when the config did not set `cudagraph_capture_sizes`. The
                list is copied, so the caller's list is never reordered.

        Sets `cudagraph_capture_sizes` (de-duplicated, descending),
        `compile_sizes` (with "cudagraph_capture_sizes" expanded),
        `max_capture_size` and `bs_to_padded_graph_size`.
        """
        if self.cudagraph_capture_sizes is None:
            # Copy: the list is sorted in place below, which would otherwise
            # mutate the caller's argument.
            self.cudagraph_capture_sizes = list(cudagraph_capture_sizes)
        else:
            # de-duplicate the sizes provided by the config
            dedup_sizes = list(set(self.cudagraph_capture_sizes))
            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
                logger.info(
                    (
                        "cudagraph sizes specified by model runner"
                        " %s is overridden by config %s"
                    ),
                    cudagraph_capture_sizes,
                    dedup_sizes,
                )
            self.cudagraph_capture_sizes = dedup_sizes

        computed_compile_sizes = []
        if self.compile_sizes is not None:
            # de-duplicate the sizes provided by the config
            self.compile_sizes = list(set(self.compile_sizes))
            for x in self.compile_sizes:
                if isinstance(x, str):
                    assert x == "cudagraph_capture_sizes", (
                        "Unrecognized size type in compile_sizes, "
                        f"expect 'cudagraph_capture_sizes', got {x}"
                    )
                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
                else:
                    assert isinstance(x, int)
                    computed_compile_sizes.append(x)
        self.compile_sizes = computed_compile_sizes  # type: ignore

        # sort to make sure cudagraph capture sizes are in descending order
        self.cudagraph_capture_sizes.sort(reverse=True)
        self.max_capture_size = (
            self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
        )

        # Pre-compute, for each batch size 0..max_capture_size, the captured
        # size it is padded to: the smallest capture size >= bs (0 maps to 0).
        self.bs_to_padded_graph_size = [0] * (self.max_capture_size + 1)
        for end, start in zip(
            self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
        ):
            for bs in range(start, end):
                if bs == start:
                    self.bs_to_padded_graph_size[bs] = start
                else:
                    self.bs_to_padded_graph_size[bs] = end
        self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size

    def set_splitting_ops_for_v1(self):
        """Resolve `splitting_ops` (and possibly `cudagraph_mode`) for V1.

        Must only be called when `level` is `CompilationLevel.PIECEWISE`.
        Delegates to the inductor-graph-partition or attention-fusion
        variants when those are enabled. Otherwise a None `splitting_ops`
        defaults to a copy of the attention ops. An explicitly empty
        `splitting_ops` cannot yield piecewise cudagraphs, so PIECEWISE is
        downgraded to NONE and FULL_AND_PIECEWISE to FULL.
        """
        # NOTE: this function needs to be called only when level is
        # CompilationLevel.PIECEWISE
        assert self.level == CompilationLevel.PIECEWISE, (
            "set_splitting_ops_for_v1 should only be called when "
            "level is CompilationLevel.PIECEWISE"
        )

        if self.use_inductor_graph_partition:
            self.set_splitting_ops_for_inductor_graph_partition()
            return

        if self.pass_config.enable_attn_fusion:
            # here use_inductor_graph_partition is False
            self.set_splitting_ops_for_attn_fusion()
            return

        if self.splitting_ops is None:
            # NOTE: When using full cudagraph, instead of setting an empty
            # list and capture the full cudagraph inside the flattened fx
            # graph, we keep the piecewise fx graph structure but capture
            # the full cudagraph outside the fx graph. This reduces some
            # cpu overhead when the runtime batch_size is not cudagraph
            # captured. see https://github.com/vllm-project/vllm/pull/20059
            # for details. Make a copy to avoid mutating the class-level
            # list via reference.
            self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once("Using piecewise compilation with empty splitting_ops")
            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning_once(
                    # Trailing space needed: the literals are concatenated.
                    "Piecewise compilation with empty splitting_ops do not "
                    "contains piecewise cudagraph. Setting cudagraph_"
                    "mode to NONE. Hint: If you are using attention backends "
                    "that support cudagraph, consider manually setting "
                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
                    "full cudagraphs."
                )
                self.cudagraph_mode = CUDAGraphMode.NONE
            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops do not "
                    "contains piecewise cudagraph. Setting cudagraph_mode "
                    "to FULL."
                )
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []

    def set_splitting_ops_for_inductor_graph_partition(self):
        """Default `splitting_ops` to the attention ops when it is unset.

        Only valid when `use_inductor_graph_partition` is enabled (asserted).
        """
        assert self.use_inductor_graph_partition
        if self.splitting_ops is None:
            # Copy so the class-level `_attention_ops` list is never mutated.
            self.splitting_ops = list(self._attention_ops)

    def set_splitting_ops_for_attn_fusion(self):
        """Configure splitting for attention+quant fusion when inductor graph
        partition is off.

        Clears `splitting_ops` so the graph is not split at attention ops.
        Any mode with piecewise cudagraphs is switched to FULL, with a
        warning. Only valid when `pass_config.enable_attn_fusion` is set
        (asserted).
        """
        assert self.pass_config.enable_attn_fusion
        # For dynamo-partition (non-inductor) attention fusion,
        # set splitting_ops to empty to avoid splitting at attention ops
        self.splitting_ops = []
        # NOTE(review): assumes cudagraph_mode was resolved (not None) before
        # this is called; a None mode would raise AttributeError here.
        if self.cudagraph_mode.has_piecewise_cudagraphs():
            logger.warning_once(
                "enable_attn_fusion is incompatible with piecewise "
                "cudagraph when use_inductor_graph_partition is off. "
                "In this case, splitting_ops will be set to empty "
                "list, and cudagraph_mode will be set to FULL. "
                "Please ensure you are using attention backends that "
                "support cudagraph or set cudagraph_mode to NONE "
                "explicitly if encountering any problems."
            )
            self.cudagraph_mode = CUDAGraphMode.FULL

        assert not self.splitting_ops_contain_attention(), (
            "attention ops should not be in splitting_ops "
            "when enable_attn_fusion is True"
        )

    def splitting_ops_contain_attention(self) -> bool:
        """Return True iff `splitting_ops` is set and contains *every* op in
        `_attention_ops`; a partial subset of attention ops returns False."""
        return self.splitting_ops is not None and all(
            op in self.splitting_ops for op in self._attention_ops
        )

    def is_attention_compiled_piecewise(self) -> bool:
        """Return whether the graph is split at the attention ops.

        Requires `splitting_ops_contain_attention()`. Then:
        - without inductor graph partition (Dynamo FX split): True iff
          `level` is PIECEWISE;
        - with inductor graph partition: True iff some compilation level is
          set and `backend` is "inductor".
        """
        if not self.splitting_ops_contain_attention():
            return False

        if not self.use_inductor_graph_partition:
            # Dynamo-level FX split case
            return self.level == CompilationLevel.PIECEWISE

        # Inductor partition case
        return (
            self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
        )

    def custom_op_log_check(self):
        """
        This method logs the enabled/disabled custom ops and checks that the
        passed custom_ops field only contains relevant ops.
        It is called at the end of set_current_vllm_config,
        after the custom ops have been instantiated.

        For each '+op'/'-op' entry whose op was neither enabled nor disabled
        in this model, a one-time warning reports that the entry has no
        effect. The warning distinguishes ops missing from `CustomOp`'s
        registry from registered ops that are absent in this model.
        """

        if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
            logger.debug("No custom ops found in model.")
            return

        logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
        logger.debug("disabled custom ops: %s", self.disabled_custom_ops)

        all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops
        for op in self.custom_ops:
            if op in {"all", "none"}:
                continue

            assert op[0] in {"+", "-"}, (
                "Invalid custom op syntax (should be checked during init)"
            )

            # check if op name exists in model
            op_name = op[1:]
            if op_name not in all_ops_in_model:
                from vllm.model_executor.custom_op import CustomOp

                # Does op exist at all or is it just not present in this model?
                # Note: Only imported op classes appear in the registry.
                missing_str = (
                    "doesn't exist (or wasn't imported/registered)"
                    if op_name not in CustomOp.op_registry
                    else "not present in model"
                )

                enable_str = "enabling" if op[0] == "+" else "disabling"
                logger.warning_once(
                    "Op '%s' %s, %s with '%s' has no effect",
                    op_name,
                    missing_str,
                    enable_str,
                    op,
                )