"vllm/perf/benchmark_throughput.py" did not exist on "93872128111499f2d9127c3cd1f94b35850187d1"
compilation.py 36.1 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import enum
5
6
import hashlib
from collections import Counter
7
from collections.abc import Callable
8
from dataclasses import asdict, field
9
from pathlib import Path
10
from typing import TYPE_CHECKING, Any, ClassVar
11

12
from pydantic import TypeAdapter, field_validator
13
14
15
16
17
from pydantic.dataclasses import dataclass

from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
18
from vllm.platforms import current_platform
19
20
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname
21
22

if TYPE_CHECKING:
    # Only needed for static type checking; presumably deferred to avoid a
    # circular import with vllm.config at runtime.
    from vllm.config import VllmConfig
else:
    # Runtime placeholder so the name `VllmConfig` always exists in this module.
    VllmConfig = object

# Module-level logger used by the config classes below for deprecation and
# configuration warnings.
logger = init_logger(__name__)


30
31
32
33
34
35
36
37
38
39
40
41
42
43
class CompilationMode:
    """Integer constants selecting how torch.compile is applied to the model.

    These are plain ints rather than an Enum, so they compare directly with
    the integer `CompilationConfig.mode` field."""

    NONE = 0
    """No torch.compile at all: the model runs as-is in eager PyTorch mode."""
    STOCK_TORCH_COMPILE = 1
    """The unmodified `torch.compile` compilation pipeline."""
    DYNAMO_TRACE_ONCE = 2
    """Trace the model with Dynamo exactly once, so it is never recompiled."""
    VLLM_COMPILE = 3
    """vLLM's own Inductor-based backend, featuring caching, piecewise
    compilation, shape specialization and custom passes."""
44
45


46
class CUDAGraphMode(enum.Enum):
    """Constants for the cudagraph mode in CompilationConfig.

    Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
    treated as concrete runtime mode for cudagraph runtime dispatching.

    The composite members (`FULL_DECODE_ONLY`, `FULL_AND_PIECEWISE`) carry a
    `(decode_mode_value, mixed_mode_value)` tuple instead of a single int.
    """

    NONE = 0
    PIECEWISE = 1
    FULL = 2
    # Composite modes: (mode for decode batches, mode for mixed batches).
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def decode_mode(self) -> "CUDAGraphMode":
        """Return the concrete mode used for decode batches."""
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        """Return the concrete mode used for mixed prefill-decode batches."""
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        """Whether the concrete `mode` is used by this (possibly composite) mode.

        `mode` itself must be a concrete (non-composite) mode.
        """
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        """Whether any routine of this mode uses piecewise cudagraphs."""
        return self.has_mode(CUDAGraphMode.PIECEWISE)

    def max_cudagraph_mode(self) -> "CUDAGraphMode":
        """Return the concrete mode with the largest value among the routines."""
        return CUDAGraphMode(max(self.value)) if self.separate_routine() else self

    def has_full_cudagraphs(self) -> bool:
        """Whether any routine of this mode captures full cudagraphs."""
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

    def has_piecewise_cudagraphs(self) -> bool:
        """Whether any routine of this mode captures piecewise cudagraphs."""
        return self.requires_piecewise_compilation()

    def separate_routine(self) -> bool:
        """Whether this is a composite mode with separate decode/mixed routines."""
        return isinstance(self.value, tuple)

    def valid_runtime_modes(self) -> bool:
        """Whether this mode is a concrete mode usable for runtime dispatch."""
        return self in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)

    def __str__(self) -> str:
        return self.name

91

92
93
94
95
96
97
98
99
100
@config
@dataclass
class PassConfig:
    """Configuration for custom Inductor passes.

    This is separate from general `CompilationConfig` so that inductor passes
    don't all have access to full configuration - that would create a cycle as
    the `PassManager` is set as a property of config."""

    enable_fusion: bool = False
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
    enable_attn_fusion: bool = False
    """Whether to enable the custom attention+quant fusion pass."""
    enable_noop: bool = False
    """Whether to enable the custom no-op elimination pass."""
    enable_sequence_parallelism: bool = False
    """Whether to enable sequence parallelism."""
    enable_async_tp: bool = False
    """Whether to enable async TP."""
    enable_fi_allreduce_fusion: bool = False
    """Whether to enable flashinfer allreduce fusion."""
    fi_allreduce_fusion_max_token_num: int = 16384
    """Max number of tokens to be used in flashinfer allreduce fusion."""

    # TODO(luka) better pass enabling system.

    def uuid(self):
        """
        Produces a hash unique to the pass configuration.
        Any new fields that affect compilation should be added to the hash.
        Any future fields that don't affect compilation should be excluded.
        """
        return InductorPass.hash_dict(asdict(self))

    def __post_init__(self) -> None:
        """Warn when a fusion pass is enabled without no-op elimination.

        As the warning text states, the fusion passes might not work unless
        the no-op (reshape) elimination pass (`enable_noop`) is also enabled.
        """
        if self.enable_noop:
            return
        if self.enable_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "RMSNorm/SiluMul + quant (fp8) fusion might not work"
            )
        if self.enable_attn_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "Attention + quant (fp8) fusion might not work"
            )
138
139
140
141
142
143
144
145


@config
@dataclass
class CompilationConfig:
    """Configuration for compilation. It has three parts:

    - Top-level Compilation control:
146
        - [`mode`][vllm.config.CompilationConfig.mode]
147
148
149
150
151
152
153
        - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
        - [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
        - [`backend`][vllm.config.CompilationConfig.backend]
        - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
        - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
    - CudaGraph capture:
        - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
154
        - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
        - [`cudagraph_capture_sizes`]
        [vllm.config.CompilationConfig.cudagraph_capture_sizes]
        - [`cudagraph_num_of_warmups`]
        [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
        - [`cudagraph_copy_inputs`]
        [vllm.config.CompilationConfig.cudagraph_copy_inputs]
        - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
    - Inductor compilation:
        - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
        - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
        - [`inductor_compile_config`]
        [vllm.config.CompilationConfig.inductor_compile_config]
        - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
        - custom inductor passes

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used
        for the same size. We need to capture all the sizes we want to use.
    - inductor: a graph compiled by inductor for a general shape can be used
        for different sizes. Inductor can also compile for specific sizes,
        where it can have more information to optimize the graph with fully
        static shapes. However, we find the general shape compilation is
        sufficient for most cases. It might be beneficial to compile for
        certain small batchsizes, where inductor is good at optimizing.
    """
180

181
    # Top-level Compilation control
182
    level: int | None = None
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
    """
    Level is deprecated and will be removed in the next release,
    either 0.12.0 or 0.11.2 whichever is soonest.
    Please use mode. Currently all levels are mapped to mode.
    """
    # Top-level Compilation control
    mode: int | None = None
    """The compilation approach used for torch.compile-based compilation of the
    model.

    - None: If None, we will select the default compilation mode.
      For V1 engine this is 3.
    - 0: NONE: No torch.compile compilation is applied, model runs in fully
         eager pytorch mode. The model runs as-is.
    - 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline.
    - 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding
         recompilation by removing guards.
         Requires no dynamic-shape-dependent control-flow.
    - 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching,
         piecewise compilation, shape specialization, and custom passes."""
203
    debug_dump_path: Path | None = None
204
205
206
207
208
    """The path to dump the debug information."""
    cache_dir: str = ""
    """The directory to store the compiled graph, to accelerate Inductor
    compilation. By default, it will use model-related information to generate
    a cache directory."""
209
    backend: str = ""
210
211
    """The backend for compilation. It needs to be a string:

212
213
    - "" (empty string): use the default backend ("inductor" on CUDA-alike
    platforms).
214
215
216
217
218
    - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
    - "full.module.name": a qualified name which can be used to import the

    backend function.
    We use string to avoid serialization issues when using compilation in a
219
    distributed setting. When the compilation mode is 1 or 2, the backend is
220
    used for the compilation directly (it sees the whole graph). When the
221
    compilation mode is 3, the backend is used for the piecewise compilation
222
    (it sees a part of the graph). The backend can not be custom for compilation
223
    mode 3, i.e. the backend must be either eager or inductor. Furthermore,
224
    compilation is only piecewise if splitting ops is set accordingly and
225
    use_inductor_graph_partition is off. Note that the default options for
226
227
    splitting ops are sufficient for piecewise compilation.
    """
228
229
230
231
232
233
234
235
236
237
    custom_ops: list[str] = field(default_factory=list)
    """Fine-grained control over which custom ops to enable/disable. Use 'all'
    to enable all, 'none' to disable all. Also specify a list of custom op
    names to enable (prefixed with a '+'), or disable (prefixed with a '-').
    Examples:

    - 'all,-op1' to enable all except op1
    - 'none,+op1,+op2' to enable only op1 and op2

    By default, all custom ops are enabled when running without Inductor and
238
    disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
239
    Inductor generates (fused) Triton kernels for disabled custom ops."""
240
    splitting_ops: list[str] | None = None
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
    """A list of ops to exclude from cudagraphs, used in piecewise compilation.

    The behavior depends on use_inductor_graph_partition:

    - When use_inductor_graph_partition=False (default):
        These ops are used for Dynamo FX-level graph splitting. The graph is
        split at these ops before Inductor compilation, creating separate
        subgraphs for cudagraph capture.

    - When use_inductor_graph_partition=True:
        These ops are used to register Inductor partition rules. The graph
        partitioning happens at Inductor codegen time after all passes and
        fusions are finished, allowing compilation and custom passes to operate
        on the full graph while still excluding these ops from cudagraphs.

    If None, defaults to attention ops for piecewise cudagraphs.
    If empty list [], no ops are excluded (suitable for full cudagraphs)."""
258
259

    # Inductor capture
260
261
262
263
264
265
    use_inductor: bool | None = None
    """
    Whether to use inductor compilation.

    This flag is deprecated and will be removed in the next release 0.12.0.
    Please use the 'backend' option instead.
266
267
268
269
270
271
272

    - False: inductor compilation is not used. graph runs in eager
        (custom_ops enabled by default).
    - True: inductor compilation is used (custom_ops disabled by default).
        One graph for symbolic shape and one graph per size in compile_sizes
        are compiled using configurations in inductor_compile_config.

273
    This setting is ignored if mode<VLLM_COMPILE.
274
275
276
277

    For future compatibility:
    If use_inductor is True, backend="inductor" otherwise backend="eager".
    """
278
    compile_sizes: list[int | str] | None = None
279
280
281
282
283
284
285
286
287
288
289
290
291
292
    """Sizes to compile for inductor. In addition
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""
    inductor_compile_config: dict = field(default_factory=dict)
    """Additional configurations for inductor.
    - None: use default configurations."""
    inductor_passes: dict[str, str] = field(default_factory=dict)
    """Additional passes for inductor. It is a dictionary
    from pass name to pass function qualified name. We use function
    name because the config uses JSON format. If we pass the config
    from Python, functions can also be passed directly via Python object
    constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

    # CudaGraph compilation
293
    cudagraph_mode: CUDAGraphMode | None = None
294
    """
Harry Mellor's avatar
Harry Mellor committed
295
296
    The mode of the cudagraph:

297
    - NONE, no cudagraph capture.
298
    - PIECEWISE.
299
300
    - FULL.
    - FULL_DECODE_ONLY.
301
    - FULL_AND_PIECEWISE. (v1 default)
302
303

    PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
co63oc's avatar
co63oc committed
304
    incompatible ops (i.e. some attention ops) outside the cudagraph
305
306
307
308
309
310
311
312
313
314
315
316
317
    for general flexibility.

    FULL mode: Capture full cudagraph for all batches. Can be good for small
    models or workloads with small prompts; not supported by many backends.
    Generally for performance FULL_AND_PIECEWISE is better.
    
    FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
    Mixed prefill-decode batches are run without cudagraphs. Can be good for
    decode instances in a P/D setup where prefill is not as important so we
    can save some memory.
    
    FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
    piecewise cudagraph for prefill and mixed prefill-decode batches.
318
    This is the most performant mode for most models and is the default.
319
320
321
322

    Currently, the cudagraph mode is only used for the v1 engine.
    Note that the cudagraph logic is generally orthogonal to the 
    compilation logic. While piecewise cudagraphs require piecewise 
323
    compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
324
325
326
327
328
329
    cudagraphs are supported with and without compilation.
    
    Warning: This flag is new and subject to change in addition 
    more modes may be added.
    """
    use_cudagraph: bool = True
330
331
332
333
334
335
    """Whether to use cudagraph inside compilation.
    - False: cudagraph inside compilation is not used.
    - True: cudagraph inside compilation is used. It requires
        that all input buffers have fixed addresses, and all
        splitting ops write their outputs to input buffers.
    In the vLLM V1 Engine, this flag only applies for
336
    CompilationMode.VLLM_COMPILE (aka -O3).
337
338
    Note that this is orthogonal to the cudagraph capture logic
    outside of compilation.
339
    Warning: This flag is deprecated and will be removed in the next major or
340
341
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
    instead.
342
    """
343
344
345
346
347
    cudagraph_num_of_warmups: int = 0
    """Number of warmup runs for cudagraph.
    It means the first several runs will be treated as warmup runs.
    Only after that, the execution will be recorded, and the recorded
    cudagraph will be used for subsequent runs."""
348
    cudagraph_capture_sizes: list[int] | None = None
349
350
351
352
353
354
355
356
    """Sizes to capture cudagraph.
    - None (default): capture sizes are inferred from vllm config.
    - list[int]: capture sizes are specified as given."""
    cudagraph_copy_inputs: bool = False
    """Whether to copy input tensors for
    cudagraph. If the caller can guarantee that the same input buffers
    are always used, it can set this to False. Otherwise, it should
    set this to True, and the compiler will copy the input to an
357
358
359
    internally managed buffer. Default is False. 
    Note that this flag is only effective when cudagraph_mode is PIECEWISE.
    """
360
    full_cuda_graph: bool | None = False
361
362
363
    """whether to use a full cuda graph for the entire forward pass rather than
    splitting certain operations such as attention into subgraphs. Thus this
    flag cannot be used together with splitting_ops. This may provide
364
365
    performance benefits for smaller models.
    Warning: This flag is deprecated and will be removed in the next major or
366
367
    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
    FULL_AND_PIECEWISE instead.
368
    """
369

370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
    use_inductor_graph_partition: bool = False
    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
    This partition happens at inductor codegen time after all passes and fusions
    are finished. It generates a single `call` function which wraps
    cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops
    outside the partition functions. For a graph with N cudagraph-unsafe ops
    (e.g., Attention), there would be N+1 partitions. To mark an op as
    cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
    register the custom op. 

    This config supports both full cudagraph and piecewise cudagraph without
    compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
    to each partition. For N+1 partitions, there would be N+1
    CUDAGraph wrapper instances.

    For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
    inductor `call` function in the model runner. The top-level full cudagraph
    capture ignores all partitioning.
    """

390
391
392
393
394
395
396
397
398
    pass_config: PassConfig = field(default_factory=PassConfig)
    """Custom inductor passes, see PassConfig for more details"""

    max_capture_size: int = field(default=None, init=False)  # type: ignore
    """not configurable, computed after init"""
    local_cache_dir: str = field(default=None, init=False)  # type: ignore
    """local cache dir for each rank"""
    bs_to_padded_graph_size: list[int] = field(
        default=None,  # type: ignore
399
400
        init=False,
    )
401
402
403
404
405
406
    """optimization:
    Intuitively, bs_to_padded_graph_size should be dict[int, int].
    since we know all keys are in a range [0, max_capture_size],
    we can optimize it to list[int] for better lookup performance."""

    # keep track of enabled and disabled custom ops
407
    enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
408
    """custom ops that are enabled"""
409
    disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
410
411
412
413
414
415
    """custom ops that are disabled"""
    traced_files: set[str] = field(default_factory=set, init=False)
    """files that are traced for compilation"""
    compilation_time: float = field(default=0.0, init=False)
    """time taken for compilation"""

416
    static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
417
418
419
420
    """Per-model forward context
    Map from layer name to layer objects that need to be accessed outside
    model code, e.g., Attention, FusedMOE when dp_size>1."""

421
    # Attention ops; used for piecewise cudagraphs
422
    # Use PyTorch operator format: "namespace::name"
423
    _attention_ops: ClassVar[list[str]] = [
424
425
426
427
428
429
430
431
432
433
434
        "vllm::unified_attention",
        "vllm::unified_attention_with_output",
        "vllm::unified_mla_attention",
        "vllm::unified_mla_attention_with_output",
        "vllm::mamba_mixer2",
        "vllm::mamba_mixer",
        "vllm::short_conv",
        "vllm::linear_attention",
        "vllm::plamo2_mamba_mixer",
        "vllm::gdn_attention",
        "vllm::sparse_attn_indexer",
435
436
    ]

437
438
439
440
441
442
443
444
445
446
447
448
449
    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The hex SHA-256 digest of the stringified factors.
        """

        def _pass_id(obj: Any) -> Any:
            # Stable identifier for a pass given as an object. The default
            # repr of a function (and of objects without a custom __repr__)
            # embeds a memory address, which would change the hash between
            # processes.
            if isinstance(obj, InductorPass):
                return obj.uuid()
            module = getattr(obj, "__module__", "")
            qualname = getattr(obj, "__qualname__", type(obj).__qualname__)
            return f"{module}.{qualname}"

        factors: list[Any] = []
        factors.append(self.mode)
        factors.append(self.backend)
        factors.append(self.custom_ops)
        factors.append(self.splitting_ops)
        factors.append(self.use_inductor)
        factors.append(self.use_inductor_graph_partition)
        # __post_init__ stores InductorPass objects in this dict: hash them by
        # their uuid. Every other value is kept unchanged.
        factors.append(
            {
                k: v.uuid() if isinstance(v, InductorPass) else v
                for k, v in self.inductor_compile_config.items()
            }
        )
        # Entries are qualified names or, per the field docs, callables.
        factors.append(
            {
                k: v if isinstance(v, str) else _pass_id(v)
                for k, v in self.inductor_passes.items()
            }
        )
        factors.append(self.pass_config.uuid())
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __repr__(self) -> str:
        """Render the config as a dict string of the explicitly-set fields.

        Runtime-only state (forward context, custom-op counters, compilation
        time, padding table, traced files) and the
        `post_grad_custom_post_pass` entry of `inductor_compile_config` are
        omitted, as are `pass_config` options still at their default value.
        """
        exclude: dict[str, Any] = {
            "static_forward_context": True,
            "enabled_custom_ops": True,
            "disabled_custom_ops": True,
            "compilation_time": True,
            "bs_to_padded_graph_size": True,
            "traced_files": True,
            "inductor_compile_config": {
                "post_grad_custom_post_pass": True,
            },
        }

        # exclude default attr in pass_config
        pass_config_exclude = {
            attr: True
            for attr, default_val in vars(PassConfig()).items()
            if getattr(self.pass_config, attr) == default_val
        }
        if pass_config_exclude:
            exclude["pass_config"] = pass_config_exclude

        # Named `dumped` so it does not shadow the module-level `config`
        # decorator.
        dumped = TypeAdapter(CompilationConfig).dump_python(
            self, exclude=exclude, exclude_unset=True
        )
        return str(dumped)

    __str__ = __repr__

490
491
492
493
494
495
496
497
498
499
    @field_validator("cudagraph_mode", mode="before")
    @classmethod
    def validate_cudagraph_mode_before(cls, value: Any) -> Any:
        """
        enable parse the `cudagraph_mode` enum type from string

        A string is matched case-insensitively against the `CUDAGraphMode`
        member names, e.g. "full_and_piecewise" -> FULL_AND_PIECEWISE. Any
        non-string value is returned unchanged for normal validation.

        Raises:
            ValueError: if the string does not name a `CUDAGraphMode` member.
                Pydantic reports a validator's ValueError as a validation
                error, unlike a bare KeyError, which would escape as-is.
        """
        if not isinstance(value, str):
            return value
        try:
            return CUDAGraphMode[value.upper()]
        except KeyError:
            valid = ", ".join(member.name for member in CUDAGraphMode)
            raise ValueError(
                f"Invalid cudagraph_mode {value!r}; expected one of: {valid}"
            ) from None

500
    def __post_init__(self) -> None:
        """Normalize, migrate and validate the config after construction.

        Steps, in order:

        - map the deprecated `level` onto `mode` when `mode` is unset;
        - allow at most one of the 'all' / 'none' markers in `custom_ops`;
        - on torch>=2.6, disable auto-functionalization V2 unless configured;
        - turn every `inductor_passes` entry (callable or qualified name) into
          an `InductorPass` stored in `inductor_compile_config`;
        - build `pass_config` from a plain dict if one was given;
        - on torch>=2.9.0.dev, enable combo kernels unless configured;
        - migrate the deprecated `use_cudagraph` / `full_cuda_graph` flags
          into `cudagraph_mode`;
        - validate graph-partition support, `custom_ops` syntax and the
          backend used with VLLM_COMPILE;
        - migrate the deprecated `use_inductor` flag into `backend`, then fill
          an empty backend with the platform default.

        Raises:
            ValueError: on conflicting deprecated/new cudagraph flags,
                `use_inductor_graph_partition` with torch<2.9.0.dev, a
                malformed `custom_ops` entry, or an unsupported backend for
                VLLM_COMPILE.
        """
        if self.level is not None:
            logger.warning(
                "Level is deprecated and will be removed in the next release, "
                "either 0.12.0 or 0.11.2 whichever is soonest. "
                "Use mode instead. "
                "If both level and mode are given, "
                "only mode will be used."
            )
            if self.mode is None:
                self.mode = self.level

        count_none = self.custom_ops.count("none")
        count_all = self.custom_ops.count("all")
        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"

        # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2:
        # 1. A bug in PyTorch, fixed in 2.7:
        #    https://github.com/pytorch/pytorch/issues/147924
        # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't
        #    work with V2. Addressing this will take extra engineering effort
        #    and it is not yet a priority. RFC here:
        #    https://github.com/vllm-project/vllm/issues/14703

        if is_torch_equal_or_newer("2.6"):
            # Respect an explicit user setting; otherwise force V1.
            self.inductor_compile_config.setdefault(
                "enable_auto_functionalized_v2", False
            )

        for k, v in self.inductor_passes.items():
            if isinstance(v, str):
                # `__import__("pkg.mod")` returns the top-level package `pkg`,
                # so a dict lookup on it misses functions in nested modules.
                # Resolve the complete dotted name instead.
                func = resolve_obj_by_qualname(v)
            else:
                assert callable(v), f"pass {k} should be callable or a qualified name"
                func = v
            self.inductor_compile_config[k] = (
                func if isinstance(func, InductorPass) else CallableInductorPass(func)
            )

        if isinstance(self.pass_config, dict):
            self.pass_config = PassConfig(**self.pass_config)

        if (
            is_torch_equal_or_newer("2.9.0.dev")
            and "combo_kernels" not in self.inductor_compile_config
            and "benchmark_combo_kernel" not in self.inductor_compile_config
        ):
            # use horizontal fusion, which is useful for fusing qk-norm and
            # qk-rope when query and key have different shapes.
            self.inductor_compile_config["combo_kernels"] = True
            self.inductor_compile_config["benchmark_combo_kernel"] = True

        # migrate the deprecated flags
        if not self.use_cudagraph:
            logger.warning(
                "use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
            )
            if (
                self.cudagraph_mode is not None
                and self.cudagraph_mode != CUDAGraphMode.NONE
            ):
                raise ValueError(
                    "use_cudagraph and cudagraph_mode are mutually"
                    " exclusive, prefer cudagraph_mode since "
                    "use_cudagraph is deprecated."
                )
            self.cudagraph_mode = CUDAGraphMode.NONE
        if self.full_cuda_graph:
            logger.warning(
                "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
            )
            if (
                self.cudagraph_mode is not None
                and not self.cudagraph_mode.has_full_cudagraphs()
            ):
                raise ValueError(
                    "full_cuda_graph and cudagraph_mode are "
                    "mutually exclusive, prefer cudagraph_mode "
                    "since full_cuda_graph is deprecated."
                )
            # NOTE(review): an explicit FULL_DECODE_ONLY / FULL_AND_PIECEWISE
            # passes the check above but is replaced by FULL here; confirm
            # this is intended.
            self.cudagraph_mode = CUDAGraphMode.FULL

        if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
            "2.9.0.dev"
        ):
            raise ValueError(
                "use_inductor_graph_partition is only "
                "supported with torch>=2.9.0.dev. Set "
                "use_inductor_graph_partition=False instead."
            )

        for op in self.custom_ops:
            # startswith() instead of op[0] so an empty entry produces the
            # ValueError below rather than an IndexError.
            if not op.startswith(("+", "-")) and op not in {"all", "none"}:
                raise ValueError(
                    f"Invalid syntax '{op}' for custom op, "
                    "must be 'all', 'none', '+op' or '-op' "
                    "(where 'op' is the registered op name)"
                )

        # Currently only eager and inductor backend are supported
        # for piecewise compilation. Custom backends are not supported for
        # piecewise compilation. Update when more backends are supported.
        if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [
            "",
            "eager",
            "inductor",
        ]:
            raise ValueError(
                f"Invalid backend for piecewise compilation: {self.backend}"
            )

        if self.use_inductor is not None:
            logger.warning_once(
                "The 'use_inductor' flag is deprecated and will be "
                "removed in the next release (v0.12.0). "
                "Please use the 'backend' option instead.",
            )
            self.backend = "inductor" if self.use_inductor else "eager"

        if self.backend == "":
            self.backend = current_platform.simple_compile_backend

629
    def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
        """
        Initialize the backend for the compilation config from a vllm config.

        Arguments:
            vllm_config: The vllm config to initialize the backend from.
        Returns:
            The backend for the compilation config. For STOCK_TORCH_COMPILE
            and DYNAMO_TRACE_ONCE this is either the name of a registered
            torch backend or the object resolved from `backend` as a
            qualified name; for VLLM_COMPILE it is a `VllmBackend` instance.
        Raises:
            ValueError: if `mode` is unset or NONE, or if `backend` is not
                "eager"/"inductor" under VLLM_COMPILE.
        """
        if self.mode is None:
            # Implicit concatenation (not a backslash continuation inside the
            # literal) so the message carries no embedded indentation.
            raise ValueError(
                "No compilation mode is set. This method should only be "
                "called via vllm config where the mode is set if none is "
                "provided."
            )
        if self.mode == CompilationMode.NONE:
            raise ValueError("No compilation mode is set.")

        from torch._dynamo.backends.registry import list_backends

        # Empty exclude_tags: consider every registered backend.
        torch_backends = list_backends(exclude_tags=())
        if self.mode in [
            CompilationMode.STOCK_TORCH_COMPILE,
            CompilationMode.DYNAMO_TRACE_ONCE,
        ]:
            if self.backend in torch_backends:
                return self.backend
            return resolve_obj_by_qualname(self.backend)

        assert self.mode == CompilationMode.VLLM_COMPILE
        if self.backend not in ["eager", "inductor"]:
            raise ValueError(
                f"Invalid backend for piecewise compilation: {self.backend}"
            )

        from vllm.compilation.backends import VllmBackend

        return VllmBackend(vllm_config)

667
    def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None:
        """To complete the initialization of config,
        we need to know the cudagraph sizes.

        Args:
            cudagraph_capture_sizes: capture sizes proposed by the caller. They
                are used only if the config did not specify its own, and the
                list is copied, never modified in place.

        Sets `cudagraph_capture_sizes` (deduplicated, descending order),
        `compile_sizes` (with "cudagraph_capture_sizes" expanded, without
        duplicates), `max_capture_size` and `bs_to_padded_graph_size`.
        """
        if self.cudagraph_capture_sizes is None:
            # Copy: the list is sorted in place below and must not alias the
            # caller's list.
            self.cudagraph_capture_sizes = list(cudagraph_capture_sizes)
        else:
            # de-duplicate the sizes provided by the config
            dedup_sizes = list(set(self.cudagraph_capture_sizes))
            if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
                logger.info(
                    (
                        "cudagraph sizes specified by model runner"
                        " %s is overridden by config %s"
                    ),
                    cudagraph_capture_sizes,
                    dedup_sizes,
                )
            self.cudagraph_capture_sizes = dedup_sizes

        computed_compile_sizes: list[int] = []
        if self.compile_sizes is not None:
            # dict.fromkeys de-duplicates while keeping the user's order; a set
            # mixing str and int has a per-process (hash-seeded) order.
            for x in dict.fromkeys(self.compile_sizes):
                if isinstance(x, str):
                    assert x == "cudagraph_capture_sizes", (
                        "Unrecognized size type in compile_sizes, "
                        f"expect 'cudagraph_capture_sizes', got {x}"
                    )
                    computed_compile_sizes.extend(self.cudagraph_capture_sizes)
                else:
                    assert isinstance(x, int)
                    computed_compile_sizes.append(x)
        # Expanding "cudagraph_capture_sizes" can repeat sizes that were also
        # listed explicitly; keep each size once.
        self.compile_sizes = list(dict.fromkeys(computed_compile_sizes))  # type: ignore

        # sort to make sure cudagraph capture sizes are in descending order
        self.cudagraph_capture_sizes.sort(reverse=True)
        self.max_capture_size = (
            self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
        )

        # pre-compute the mapping from batch size to padded graph size: each
        # batch size maps to the smallest capture size >= it (0 maps to 0).
        self.bs_to_padded_graph_size = [0] * (self.max_capture_size + 1)
        for end, start in zip(
            self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]
        ):
            for bs in range(start, end):
                if bs == start:
                    self.bs_to_padded_graph_size[bs] = start
                else:
                    self.bs_to_padded_graph_size[bs] = end
        self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size
720
721

    def set_splitting_ops_for_v1(self):
        """Resolve ``splitting_ops`` (and possibly ``cudagraph_mode``) for V1.

        Must only be called when ``mode`` is ``CompilationMode.VLLM_COMPILE``.

        - With ``use_inductor_graph_partition``, delegates to
          ``set_splitting_ops_for_inductor_graph_partition``.
        - Otherwise, with ``pass_config.enable_attn_fusion``, delegates to
          ``set_splitting_ops_for_attn_fusion``.
        - Otherwise, an unset (None) ``splitting_ops`` defaults to a copy of
          ``_attention_ops``; an explicitly empty list is kept, and a
          cudagraph mode with piecewise graphs is downgraded
          (``PIECEWISE`` -> ``NONE``, ``FULL_AND_PIECEWISE`` -> ``FULL``)
          with a one-time warning.
        """
        # NOTE: this function needs to be called only when mode is
        # CompilationMode.VLLM_COMPILE
        assert self.mode == CompilationMode.VLLM_COMPILE, (
            "set_splitting_ops_for_v1 should only be called when "
            "mode is CompilationMode.VLLM_COMPILE"
        )

        if self.use_inductor_graph_partition:
            self.set_splitting_ops_for_inductor_graph_partition()
            return

        if self.pass_config.enable_attn_fusion:
            # here use_inductor_graph_partition is False
            self.set_splitting_ops_for_attn_fusion()
            return

        if self.splitting_ops is None:
            # NOTE: When using full cudagraph, instead of setting an empty
            # list and capture the full cudagraph inside the flattened fx
            # graph, we keep the piecewise fx graph structure but capture
            # the full cudagraph outside the fx graph. This reduces some
            # cpu overhead when the runtime batch_size is not cudagraph
            # captured. see https://github.com/vllm-project/vllm/pull/20059
            # for details. Make a copy to avoid mutating the class-level
            # list via reference.
            self.splitting_ops = list(self._attention_ops)
        elif len(self.splitting_ops) == 0:
            logger.warning_once("Using piecewise compilation with empty splitting_ops")
            # An unsplit graph has no pieces, so piecewise cudagraphs cannot
            # be captured; drop the piecewise part of the requested mode.
            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops does not "
                    "contain piecewise cudagraph. Setting cudagraph_"
                    "mode to NONE. Hint: If you are using attention backends "
                    "that support cudagraph, consider manually setting "
                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
                    "full cudagraphs."
                )
                self.cudagraph_mode = CUDAGraphMode.NONE
            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                logger.warning_once(
                    "Piecewise compilation with empty splitting_ops does not "
                    "contain piecewise cudagraph. Setting cudagraph_mode "
                    "to FULL."
                )
                self.cudagraph_mode = CUDAGraphMode.FULL
            self.splitting_ops = []

    def set_splitting_ops_for_inductor_graph_partition(self):
        """Default ``splitting_ops`` to the attention ops for Inductor partitioning.

        Only an unset (None) ``splitting_ops`` is filled in; any explicit
        value, including an empty list, is left untouched. A fresh copy of
        ``_attention_ops`` is stored so the class-level list is never
        mutated through this instance.
        """
        assert self.use_inductor_graph_partition
        if self.splitting_ops is not None:
            return
        self.splitting_ops = [*self._attention_ops]

    def set_splitting_ops_for_attn_fusion(self):
        """Disable FX-graph splitting for attention fusion without Inductor partition.

        ``splitting_ops`` is unconditionally replaced by an empty list. If the
        current ``cudagraph_mode`` includes piecewise cudagraphs, a one-time
        warning is logged and the mode is switched to ``CUDAGraphMode.FULL``.
        Finally asserts that the resulting ``splitting_ops`` does not contain
        the attention ops.
        """
        assert self.pass_config.enable_attn_fusion

        # Dynamo-partition (non-Inductor) attention fusion: do not split the
        # graph at attention ops at all.
        self.splitting_ops = []

        wants_piecewise = self.cudagraph_mode.has_piecewise_cudagraphs()
        if wants_piecewise:
            # No split points means nothing to capture piecewise; fall back to
            # full cudagraphs.
            logger.warning_once(
                "enable_attn_fusion is incompatible with piecewise "
                "cudagraph when use_inductor_graph_partition is off. "
                "In this case, splitting_ops will be set to empty "
                "list, and cudagraph_mode will be set to FULL. "
                "Please ensure you are using attention backends that "
                "support cudagraph or set cudagraph_mode to NONE "
                "explicitly if encountering any problems."
            )
            self.cudagraph_mode = CUDAGraphMode.FULL

        # Sanity check on the value assigned above.
        assert not self.splitting_ops_contain_attention(), (
            "attention ops should not be in splitting_ops "
            "when enable_attn_fusion is True"
        )

    def splitting_ops_contain_attention(self) -> bool:
        """Return True when every op in ``_attention_ops`` is in ``splitting_ops``.

        Returns False if ``splitting_ops`` is unset (None) or if at least one
        attention op is missing from it.
        """
        configured = self.splitting_ops
        if configured is None:
            return False
        return not any(
            attn_op not in configured for attn_op in self._attention_ops
        )

    def is_attention_compiled_piecewise(self) -> bool:
        """Report whether attention ops act as split points in compilation.

        Requires ``splitting_ops_contain_attention()``. Then:

        - with ``use_inductor_graph_partition``: the backend must be
          ``"inductor"`` and ``mode`` must be above ``CompilationMode.NONE``;
        - otherwise (Dynamo-level FX split): ``mode`` must be
          ``CompilationMode.VLLM_COMPILE``.
        """
        if not self.splitting_ops_contain_attention():
            return False

        if self.use_inductor_graph_partition:
            # Inductor partition case
            return self.backend == "inductor" and self.mode > CompilationMode.NONE

        # Dynamo-level FX split case
        return self.mode == CompilationMode.VLLM_COMPILE

    def custom_op_log_check(self):
        """
        Log the enabled/disabled custom ops and warn about ineffective
        ``custom_ops`` entries.

        It is called at the end of set_current_vllm_config, after the custom
        ops have been instantiated. Apart from ``"all"`` and ``"none"``, each
        ``custom_ops`` entry must be ``+name`` or ``-name``. Any entry whose
        op is neither enabled nor disabled in this model gets a one-time
        warning that it has no effect. The warning says whether the name is
        missing from ``CustomOp.op_registry`` or is registered but not present
        in the model.
        """
        enabled_ops = self.enabled_custom_ops
        disabled_ops = self.disabled_custom_ops

        if not enabled_ops and not disabled_ops:
            logger.debug("No custom ops found in model.")
            return

        logger.debug("enabled custom ops: %s", enabled_ops)
        logger.debug("disabled custom ops: %s", disabled_ops)

        known_ops = enabled_ops | disabled_ops
        for entry in self.custom_ops:
            # "all"/"none" are not op names, so there is nothing to verify.
            if entry in {"all", "none"}:
                continue

            sign = entry[0]
            assert sign in {"+", "-"}, (
                "Invalid custom op syntax (should be checked during init)"
            )

            name = entry[1:]
            if name not in known_ops:
                # Imported here, only when a mismatch needs explaining.
                from vllm.model_executor.custom_op import CustomOp

                # Only imported op classes appear in the registry, so an
                # unregistered name may simply never have been imported.
                if name in CustomOp.op_registry:
                    missing_str = "not present in model"
                else:
                    missing_str = "doesn't exist (or wasn't imported/registered)"

                if sign == "+":
                    enable_str = "enabling"
                else:
                    enable_str = "disabling"

                logger.warning_once(
                    "Op '%s' %s, %s with '%s' has no effect",
                    name,
                    missing_str,
                    enable_str,
                    entry,
                )