1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
5
import copy
import os
6
from collections.abc import Callable
7
from contextlib import ExitStack
8
from typing import Any, Literal
9
10
11
12
13
14
from unittest.mock import patch

import torch
import torch._inductor.compile_fx
import torch.fx as fx

15
import vllm.envs as envs
16
from vllm.compilation.counter import compilation_counter
17
from vllm.config import VllmConfig
18
from vllm.config.utils import Range
19
from vllm.logger import init_logger
20
from vllm.utils.hashing import safe_hash
21
from vllm.utils.torch_utils import is_torch_equal_or_newer
22

23
24
# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)

25
26
27
28
29

class CompilerInterface:
    """
    The interface for a compiler that can be used by vLLM.

    The base-class hooks are deliberately inert: `initialize_cache` does
    nothing, `compute_hash` contributes an empty string, `compile` returns
    `(None, None)`, and `load` raises `NotImplementedError`.
    """

    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Initialize the compiler's own cache for `cache_dir`.

        When the vLLM process uses `cache_dir` as the cache directory,
        the compiler should initialize itself with the cache directory,
        e.g. by redirecting its own cache directory to a sub-directory.

        `prefix` can be used in combination with `cache_dir` to figure out
        the base cache directory, e.g. when multiple parts of the model are
        compiled but should all share the same base cache directory:

            cache_dir = "/path/to/dir/backbone", prefix = "backbone"
            cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"

        The default implementation does nothing.
        """
        pass

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Gather all the relevant information from the vLLM config
        to compute a hash, so that we can cache the compiled model.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        to check what information is already considered by default. This
        function should only consider the information that is specific to
        the compiler.

        The default implementation contributes nothing (an empty string).
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile the graph with the given example inputs and compiler config
        for a range of input sizes.

        `compile_range` specifies the range of the inputs. It can be a
        concrete size (if compile_sizes is provided), e.g. [4, 4], or a
        range, e.g. [5, 8]. Right now we only support one variable in ranges
        for all inputs, which is the batch size (number of tokens) during
        inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        The function should return a compiled callable function, as well as
        a handle that can be used to directly load the compiled function.

        The handle should be a plain Python object, preferably a string or a
        file path for readability.

        If the compiler doesn't support caching, it should return None for the
        handle. If the compiler fails to compile the graph, it should return
        None for the compiled function as well.

        `key` is required by `InductorStandaloneAdaptor`: it specifies where
        to save the compiled artifact, which gets saved to `cache_dir/key`.

        The default implementation returns `(None, None)`.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Load the compiled function from the handle.
        Raises an error if the handle is invalid.

        The handle is the second return value of the `compile` function.

        Raises:
            NotImplementedError: always, in the default implementation.
        """
        raise NotImplementedError("caching is not supported")


class AlwaysHitShapeEnv:
    """
    A dummy shape environment whose guard checks always succeed.

    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have a shape environment to provide to Inductor, and
    it will fail the Inductor code cache lookup.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods were obtained by trial-and-error
    until it worked.
    """

    def __init__(self) -> None:
        # No guards are ever recorded; each instance gets its own empty list.
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
        # Every guard expression is reported as satisfied, whatever the input.
        return True

    def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
        # There are never any guards to keep after pruning.
        return []

    def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
        # An empty guard expression: nothing to check.
        return ""


154
155
def get_inductor_factors() -> list[Any]:
    factors: list[Any] = []
156
157
    # summarize system state
    from torch._inductor.codecache import CacheBase
158

159
160
161
162
163
    system_factors = CacheBase.get_system()
    factors.append(system_factors)

    # summarize pytorch state
    from torch._inductor.codecache import torch_key
164

165
166
167
168
169
    torch_factors = torch_key()
    factors.append(torch_factors)
    return factors


170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def is_compile_cache_enabled(
    vllm_additional_inductor_config: dict[str, Any] | None,
) -> bool:
    """
    Return whether the compile cache is enabled.

    The cache counts as enabled only if none of the following disable it:
    `envs.VLLM_DISABLE_COMPILE_CACHE`,
    `torch._inductor.config.force_disable_caches`, or a truthy
    `"force_disable_caches"` entry in `vllm_additional_inductor_config`.

    Args:
        vllm_additional_inductor_config: Extra Inductor config supplied by
            vLLM. ``None`` is treated like an empty dict, because callers may
            receive a ``None`` compiler config.
    """
    vllm_inductor_config_disable_cache = (vllm_additional_inductor_config or {}).get(
        "force_disable_caches", False
    )

    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
    # with torch.compiler.config.force_disable_caches when minimum PyTorch
    # version reaches 2.10
    return (
        not envs.VLLM_DISABLE_COMPILE_CACHE
        and not torch._inductor.config.force_disable_caches
        and not vllm_inductor_config_disable_cache
    )


187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def _patch_standalone_compile_atomic_save() -> None:
    """Backport of pytorch/pytorch#162432 for torch < 2.10.0.

    Replaces ``CompiledArtifact.save`` with a version that writes the
    ``"binary"`` format through ``write_atomic``. This prevents corrupt cache
    files when multiple processes compile concurrently. Any other format is
    forwarded unchanged to the original ``save``. Calling this more than once
    is a no-op, because the installed function carries a ``_vllm_patched``
    marker.
    """
    from torch._inductor.codecache import write_atomic
    from torch._inductor.standalone_compile import CompiledArtifact

    upstream_save = CompiledArtifact.save
    if getattr(upstream_save, "_vllm_patched", False):
        # An earlier call already installed the patch.
        return

    def _write_binary_atomically(artifact: Any, path: str) -> None:
        # Serialize the artifact into one buffer, then hand it to
        # write_atomic in a single call.
        from torch._dynamo.utils import dynamo_timed
        from torch._inductor.codecache import torch_key
        from torch.utils._appending_byte_serializer import BytesWriter

        with dynamo_timed("CompiledArtifact.save"):
            assert artifact._artifacts is not None
            payload, info = artifact._artifacts
            assert len(info.aot_autograd_artifacts) == 1, info
            aot_key = info.aot_autograd_artifacts[0]
            assert not os.path.isdir(path)
            out = BytesWriter()
            # Layout: torch_key(), then the AOT autograd key, then the bytes.
            out.write_bytes(torch_key())
            out.write_str(aot_key)
            out.write_bytes(payload)
            write_atomic(path, out.to_bytes())

    def _atomic_save(
        self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        if format == "binary":
            _write_binary_atomically(self, path)
            return None
        # Non-binary formats keep the upstream implementation.
        return upstream_save(self, path=path, format=format)

    _atomic_save._vllm_patched = True  # type: ignore[attr-defined]
    CompiledArtifact.save = _atomic_save  # type: ignore[assignment]
    logger.debug(
        "Patched %s.save for atomic writes (torch < 2.10)", CompiledArtifact.__name__
    )


228
229
230
231
232
233
234
class InductorStandaloneAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, built on
    `torch._inductor.standalone_compile`.
    Requires PyTorch 2.8+.
    This is not on by default yet, but we plan to turn it on by default for
    PyTorch 2.8.

    Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
    """

    name = "inductor_standalone"

    def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
        # Before torch 2.10, install the backported atomic `save` so that
        # concurrent processes cannot leave corrupt binary artifacts behind.
        if not is_torch_equal_or_newer("2.10.0"):
            _patch_standalone_compile_atomic_save()
        # Format passed to CompiledArtifact.save / load.
        self.save_format = save_format

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return a 10-hex-char hash of the Inductor factors.

        `vllm_config` is not used by this adaptor.
        """
        factors = get_inductor_factors()
        hash_str: str = safe_hash(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        # Artifacts are saved directly at `cache_dir/<key>` by `compile`.
        # `disable_cache` and `prefix` are not used by this adaptor.
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile `graph` with `torch._inductor.standalone_compile`.

        Returns:
            `(artifact, None)` in AOT mode (torch >= 2.10 with
            VLLM_USE_MEGA_AOT_ARTIFACT). Otherwise `(artifact, (key, path))`
            with `path = cache_dir/key`. The artifact is saved to `path`
            only when the compile cache is enabled.

        Raises:
            RuntimeError: if the compile cache is enabled but the compiled
                artifact is not serializable.
        """
        compilation_counter.num_inductor_compiles += 1
        current_config: dict[str, Any] = {}
        if compiler_config is not None:
            current_config.update(compiler_config)
        set_inductor_config(current_config, compile_range)
        set_functorch_config()

        # A single concrete size is specialized from the example inputs;
        # a general range keeps the dynamism recorded in the graph.
        if compile_range.is_single_size():
            dynamic_shapes = "from_example_inputs"
        else:
            dynamic_shapes = "from_graph"

        from torch._inductor import standalone_compile

        supports_aot = is_torch_equal_or_newer("2.10.0")

        if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            logger.error(
                "CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
                "is enabled but PyTorch version does not support 'aot' "
                "parameter in standalone_compile. This requires PyTorch "
                "2.10.0+. Falling back to non-AOT mode."
            )

        compile_kwargs: dict[str, Any] = {
            "dynamic_shapes": dynamic_shapes,
            "options": {
                "config_patches": current_config,
            },
        }

        use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
        # only add 'aot' parameter if both supported and enabled...
        # this will set bundled_autograd_cache
        # https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
        if use_aot:
            compile_kwargs["aot"] = True

        # Inductor's pre-grad passes don't do anything for vLLM.
        # The pre-grad passes get run even on cache-hit and negatively impact
        # vllm cold compile times by O(1s)
        # Can remove this after the following issue gets fixed
        # https://github.com/pytorch/pytorch/issues/174502
        if envs.VLLM_ENABLE_PREGRAD_PASSES:
            ctx: Any = contextlib.nullcontext()
        else:
            ctx = patch(
                "torch._inductor.compile_fx._recursive_pre_grad_passes",
                lambda gm, _: gm,
            )
        with ctx:
            compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)

        if use_aot:
            from torch._inductor.standalone_compile import AOTCompiledArtifact

            assert isinstance(compiled_graph, AOTCompiledArtifact)
            assert hasattr(compiled_graph, "serialize")
            # Return the artifact itself with no on-disk handle. Nothing is
            # saved here; the artifact exposes `serialize` (asserted above),
            # presumably so the caller can persist it — TODO confirm.
            return compiled_graph, None

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        path = os.path.join(self.cache_dir, key)

        def is_saveable_2_10(compiled_artifact: Any) -> bool:
            # can just use compiled_artifact.is_saveable in 2.11
            if compiled_artifact._artifacts is None:
                return False
            _, cache_info = compiled_artifact._artifacts
            return len(cache_info.aot_autograd_artifacts) == 1

        # `current_config` carries every entry of `compiler_config` (including
        # any "force_disable_caches"), and it is never None.
        if is_compile_cache_enabled(current_config):
            if not is_saveable_2_10(compiled_graph):
                raise RuntimeError(
                    "The compiled artifact is not serializable. This usually means "
                    "that the model code has something that is not serializable "
                    "by torch.compile in it. You can fix this by either "
                    "figuring out what is not serializable and rewriting it, "
                    "filing a bug report, "
                    "or suppressing this error by "
                    "disabling vLLM's compilation cache via "
                    "VLLM_DISABLE_COMPILE_CACHE=1 "
                    "(this will greatly increase vLLM server warm start times)."
                )
            compiled_graph.save(path=path, format=self.save_format)
            compilation_counter.num_compiled_artifacts_saved += 1
        return compiled_graph, (key, path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Load the artifact saved at `handle[1]` and wrap it so its output
        matches the graph's return convention.

        `handle` is the `(key, path)` tuple returned by `compile`.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        path = handle[1]
        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
            path=path, format=self.save_format
        )
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
            graph_output = inductor_compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper


388
389
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.

    It compiles through `compile_fx` while monkey-patching Inductor's cache
    hooks, so that the FX graph cache key and the generated file path can be
    captured and the graph reloaded later from Inductor's FX graph cache.
    """

    name = "inductor"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return a 10-hex-char hash of the Inductor factors.

        `vllm_config` is not used by this adaptor.
        """
        factors = get_inductor_factors()
        hash_str: str = safe_hash(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Record the cache directories. Unless `disable_cache` is set, also
        point Inductor and Triton at sub-directories of the base cache
        directory.

        The base cache directory is `cache_dir` with the trailing `prefix`
        stripped.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # Redirect the Inductor and Triton caches into sub-directories of the
        # base cache directory by setting their environment variables. Users
        # then only need to copy the base cache directory to another machine
        # to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile `graph` with `compile_fx`.

        While compiling, the FX graph cache key and the path of the generated
        Python file are captured so that `load` can find the graph again.

        Returns:
            `(compiled_graph, (hash_str, file_path))`.

        Raises:
            RuntimeError: if the compile cache is enabled but no cache key
                was captured.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        current_config: dict[str, Any] = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # Keep the local FX graph cache on, but disable the remote cache.
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, compile_range)
        set_functorch_config()

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str: str | None = None
        file_path: str | None = None
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
            nonlocal hash_str, file_path
            output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
            inductor_compiled_graph = output
            if inductor_compiled_graph is not None:
                compiled_fn = inductor_compiled_graph.current_callable
                file_path = compiled_fn.__code__.co_filename  # noqa
                if (
                    not file_path.startswith(self.base_cache_dir)
                    and compiled_fn.__closure__ is not None
                ):
                    # hooked in the align_inputs_from_check_idxs function
                    # in torch/_inductor/utils.py
                    for cell in compiled_fn.__closure__:
                        if not callable(cell.cell_contents):
                            continue
                        code = cell.cell_contents.__code__
                        if code.co_filename.startswith(self.base_cache_dir):
                            # this is the real file path
                            # compiled from Inductor
                            file_path = code.co_filename
                            break
                hash_str = inductor_compiled_graph._fx_graph_cache_key
            return output

        def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
            # Record the cache key (first element of the result) as it is
            # computed.
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args: Any, **kwargs: Any) -> None:
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env(*args: Any, **kwargs: Any) -> AlwaysHitShapeEnv:
            # Accept and ignore any arguments, like the patch used in `load`.
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False)
            )
            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_remote_autograd_cache=False)
            )

            compiled_graph = compile_fx(
                graph,
                example_inputs,
                inner_compile=hijacked_compile_fx_inner,
                config_patches=current_config,
            )

        # Turn off the checks if we disable the compilation cache.
        # `current_config` carries every entry of `compiler_config` (including
        # any "force_disable_caches"), and it is never None.
        if is_compile_cache_enabled(current_config):
            if hash_str is None:
                raise RuntimeError(
                    "vLLM failed to compile the model. The most "
                    "likely reason for this is that a previous compilation "
                    "failed, leading to a corrupted compilation artifact. "
                    "We recommend trying to "
                    "remove ~/.cache/vllm/torch_compile_cache and try again "
                    "to see the real issue. "
                )
            assert file_path is not None, (
                "failed to get the file path of the compiled graph"
            )
        return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Look up the compiled graph in Inductor's FX graph cache by the key
        in `handle[0]`, and adapt it to Dynamo's calling convention.

        `handle` is the `(hash_str, file_path)` tuple returned by `compile`.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

            constants = CompiledFxGraphConstantsWithGm(graph)
            inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, None, constants
            )
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                "the cache directory and try again."
            )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different mode of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        if is_torch_equal_or_newer("2.6"):
            import torch._dynamo.utils

            return torch._dynamo.utils.get_metrics_context()  # type: ignore[no-any-return]
        else:
            return contextlib.nullcontext()

673

674
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
    """
    Add range-specific Inductor tuning options to `config`, in place.

    Only a single concrete size gets the tuning flags (`max_autotune` and
    `coordinate_descent_tuning`, taken from the corresponding `envs`
    settings). For a general range, `config` is left untouched.
    """
    if not compile_range.is_single_size():
        return
    # for a specific batch size, tuning triton kernel parameters
    # can be beneficial
    config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
    config["coordinate_descent_tuning"] = (
        envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
    )
682
683


684
def set_functorch_config() -> None:
    """
    Turn off functorch's bundled autograd cache unless
    `envs.VLLM_USE_MEGA_AOT_ARTIFACT` is set.

    This mutates the module-level `torch._functorch.config`. When the flag
    is set, the current value is left as-is.
    """
    if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        torch._functorch.config.bundled_autograd_cache = False
687
688


689
690
691
692
693
694
class EagerAdaptor(CompilerInterface):
    """
    A pass-through "compiler": `compile` returns the captured FX graph
    unchanged, so it runs without any backend compilation.
    """

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Return `(graph, None)`. Caching is unsupported, so there is no handle."""
        compilation_counter.num_eager_compiles += 1
        # we don't need to compile the graph, just return the graph itself.
        # It does not support caching, return None for the handle.
        return graph, None