# compiler_interface.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
5
6
7
import copy
import hashlib
import os
from contextlib import ExitStack
8
from typing import Any, Callable, Optional
9
10
11
12
13
14
from unittest.mock import patch

import torch
import torch._inductor.compile_fx
import torch.fx as fx

15
import vllm.envs as envs
16
from vllm.compilation.counter import compilation_counter
17
from vllm.config import VllmConfig
18
from vllm.utils import is_torch_equal_or_newer
19
20
21
22
23
24


class CompilerInterface:
    """
    The interface for a compiler that can be used by vLLM.

    Every method here has a no-op default implementation: the base class
    does not cache (`initialize_cache` does nothing, `compute_hash`
    returns an empty string), does not compile (`compile` returns
    `(None, None)`) and cannot load (`load` raises `NotImplementedError`).
    Subclasses override the pieces they support.
    """

    # The name of the compiler, e.g. inductor.
    # This is a class-level attribute.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        when the vLLM process uses `cache_dir` as the cache directory,
        the compiler should initialize itself with the cache directory,
        e.g. by re-directing its own cache directory to a sub-directory.

        prefix can be used in combination with cache_dir to figure out the base
        cache directory, e.g. there're multiple parts of model being compiled,
        but we want to share the same cache directory for all of them.

        e.g.
        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"

        The default implementation does nothing.
        """
        pass

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Gather all the relevant information from the vLLM config,
        to compute a hash so that we can cache the compiled model.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        to check what information
        is already considered by default. This function should only
        consider the information that is specific to the compiler.

        The default implementation returns an empty string.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile the graph with the given example inputs and compiler config,
        with a runtime shape. If the `runtime_shape` is None, it means
        the `example_inputs` have a dynamic shape. Otherwise, the
        `runtime_shape` specifies the shape of the inputs. Right now we only
        support one variable shape for all inputs, which is the batchsize
        (number of tokens) during inference.

        Dynamo will make sure `graph(*example_inputs)` is valid.

        The function should return a compiled callable function, as well as
        a handle that can be used to directly load the compiled function.

        The handle should be a plain Python object, preferably a string or a
        file path for readability.

        If the compiler doesn't support caching, it should return None for the
        handle. If the compiler fails to compile the graph, it should return
        None for the compiled function as well.

        `key` is required for InductorStandaloneAdaptor, it specifies where to
        save the compiled artifact. The compiled artifact gets saved to
        `cache_dir/key`.

        The default implementation compiles nothing and returns `(None, None)`.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """
        Load the compiled function from the handle.
        Raises an error if the handle is invalid.

        The handle is the second return value of the `compile` function.

        The default implementation always raises `NotImplementedError`,
        since the base interface does not support caching.
        """
        raise NotImplementedError("caching is not supported")


class AlwaysHitShapeEnv:
    """
    Why do we need this class:

    For normal `torch.compile` usage, every compilation will have
    one Dynamo bytecode compilation and one Inductor compilation.
    The Inductor compilation happens under the context of the
    Dynamo bytecode compilation, and that context is used to
    determine the dynamic shape information, etc.

    For our use case, we only run Dynamo bytecode compilation once,
    and run Inductor compilation multiple times with different shapes
    plus a general shape. The compilation for specific shapes happens
    outside of the context of the Dynamo bytecode compilation. At that
    time, we don't have shape environment to provide to Inductor, and
    it will fail the Inductor code cache lookup.

    By providing a dummy shape environment that always hits, we can
    make the Inductor code cache lookup always hit, and we can
    compile the graph for different shapes as needed.

    The following dummy methods are obtained by trial-and-error
    until it works.
    """

    def __init__(self) -> None:
        # The dummy environment never records any guards.
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args, **kwargs):
        """Accept any arguments and report that the guards hold."""
        return True

    def get_pruned_guards(self, *args, **kwargs):
        """Accept any arguments and return an empty guard list."""
        return []

    def produce_guards_expression(self, *args, **kwargs):
        """Accept any arguments and return an empty guard expression."""
        return ""


149
150
def get_inductor_factors() -> list[Any]:
    factors: list[Any] = []
151
152
    # summarize system state
    from torch._inductor.codecache import CacheBase
153

154
155
156
157
158
    system_factors = CacheBase.get_system()
    factors.append(system_factors)

    # summarize pytorch state
    from torch._inductor.codecache import torch_key
159

160
161
162
163
164
165
166
167
168
169
170
171
    torch_factors = torch_key()
    factors.append(torch_factors)
    return factors


class InductorStandaloneAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler.
    Requires PyTorch 2.8+.
    This is not on by default yet, but we plan to turn it on by default for
    PyTorch 2.8.

    Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
    """

    name = "inductor_standalone"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Return the first 10 hex characters of the MD5 digest of the
        stringified Inductor factors. `vllm_config` is not consulted.
        """
        factors = get_inductor_factors()
        hash_str = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Remember `cache_dir`; `compile` saves artifacts underneath it.
        `disable_cache` and `prefix` are accepted for interface
        compatibility but are not used here.
        """
        self.cache_dir = cache_dir

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile `graph` with `torch._inductor.standalone_compile`.

        `key` must not be None: the artifact path is `cache_dir/key`.
        The artifact is written to that path unless
        `envs.VLLM_DISABLE_COMPILE_CACHE` is set.

        Returns the compiled graph and the handle `(key, path)`.
        """
        compilation_counter.num_inductor_compiles += 1

        # Work on a copy so the caller's config dict is never mutated.
        current_config: dict[str, Any] = {}
        if compiler_config is not None:
            current_config.update(compiler_config)
        set_inductor_config(current_config, runtime_shape)
        set_functorch_config()

        # A concrete integer shape means the example inputs already carry
        # the shape to specialize on; otherwise defer to the tracing context.
        if isinstance(runtime_shape, int):
            dynamic_shapes = "from_example_inputs"
        else:
            dynamic_shapes = "from_tracing_context"

        from torch._inductor import standalone_compile

        compiled_graph = standalone_compile(
            graph,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            options={"config_patches": current_config},
        )

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        path = os.path.join(self.cache_dir, key)
        if not envs.VLLM_DISABLE_COMPILE_CACHE:
            compiled_graph.save(path=path, format="unpacked")
            compilation_counter.num_compiled_artifacts_saved += 1
        return compiled_graph, (key, path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """
        Load a compiled artifact from a handle produced by `compile`.

        `handle` must be a `(key, path)` tuple of strings; only `path` is
        used. Returns a callable taking `*args` that runs the loaded
        artifact and unwraps single outputs when `graph` does not return
        a tuple.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        path = handle[1]
        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
            path=path, format="unpacked"
        )
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args):
            graph_output = inductor_compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper


258
259
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
    """

    name = "inductor"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Return the first 10 hex characters of the MD5 digest of the
        stringified Inductor factors. `vllm_config` is not consulted.
        """
        factors = get_inductor_factors()
        hash_str = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Record the cache locations and, unless `disable_cache` is set,
        point Inductor and Triton at sub-directories of the base cache dir.

        The base cache dir is `cache_dir` with its last `len(prefix)`
        characters removed (or `cache_dir` itself when `prefix` is empty).
        Sets the `TORCHINDUCTOR_CACHE_DIR` and `TRITON_CACHE_DIR`
        environment variables as a side effect.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # redirect the cache directory to a sub-directory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile `graph` with `compile_fx`, monkey-patching Inductor
        internals to capture the cache hash and the generated file path.

        Returns the compiled graph and the handle `(hash_str, file_path)`.
        Unless `envs.VLLM_DISABLE_COMPILE_CACHE` is set, raises
        `RuntimeError` when the hash could not be captured and asserts
        that the file path was captured. `key` is not used by this adaptor.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        # Work on a copy so the caller's config dict is never mutated.
        current_config: dict[str, Any] = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # disable remote cache
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, runtime_shape)
        set_functorch_config()

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        hash_str, file_path = None, None
        from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash

        if torch.__version__.startswith("2.5"):
            original_load = FxGraphCache.load
            original_load_name = "torch._inductor.codecache.FxGraphCache.load"

            def hijack_load(*args, **kwargs):
                inductor_compiled_graph = original_load(*args, **kwargs)
                nonlocal file_path
                compiled_fn = inductor_compiled_graph.current_callable
                file_path = compiled_fn.__code__.co_filename  # noqa
                if (
                    not file_path.startswith(self.base_cache_dir)
                    and compiled_fn.__closure__ is not None
                ):
                    # hooked in the align_inputs_from_check_idxs function
                    # in torch/_inductor/utils.py
                    for cell in compiled_fn.__closure__:
                        if not callable(cell.cell_contents):
                            continue
                        if cell.cell_contents.__code__.co_filename.startswith(
                            self.base_cache_dir
                        ):
                            # this is the real file path compiled from Inductor
                            file_path = cell.cell_contents.__code__.co_filename
                            break
                return inductor_compiled_graph

            hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner  # noqa
        # NOTE: do not compare `torch.__version__` as a string here; string
        # comparison is lexicographic, so "2.10" would sort before "2.6" and
        # neither branch would run. Use the same helper as the checks below.
        elif is_torch_equal_or_newer("2.6"):
            # function renamed in 2.6
            original_load_name = None

            def hijacked_compile_fx_inner(*args, **kwargs):
                output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
                nonlocal hash_str
                inductor_compiled_graph = output
                if inductor_compiled_graph is not None:
                    nonlocal file_path
                    compiled_fn = inductor_compiled_graph.current_callable
                    file_path = compiled_fn.__code__.co_filename  # noqa
                    if (
                        not file_path.startswith(self.base_cache_dir)
                        and compiled_fn.__closure__ is not None
                    ):
                        # hooked in the align_inputs_from_check_idxs function
                        # in torch/_inductor/utils.py
                        for cell in compiled_fn.__closure__:
                            if not callable(cell.cell_contents):
                                continue
                            code = cell.cell_contents.__code__
                            if code.co_filename.startswith(self.base_cache_dir):
                                # this is the real file path
                                # compiled from Inductor
                                file_path = code.co_filename
                                break
                    hash_str = inductor_compiled_graph._fx_graph_cache_key
                return output

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # hijack to get the compiled graph itself
            if original_load_name is not None:
                stack.enter_context(patch(original_load_name, hijack_load))

            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            if is_torch_equal_or_newer("2.6"):
                stack.enter_context(
                    torch._inductor.config.patch(fx_graph_remote_cache=False)
                )
                # InductorAdaptor (unfortunately) requires AOTAutogradCache
                # to be turned off to run. It will fail to acquire the hash_str
                # and error if not.
                # InductorStandaloneAdaptor (PyTorch 2.8+) fixes this problem.
                stack.enter_context(
                    torch._functorch.config.patch(enable_autograd_cache=False)
                )
                stack.enter_context(
                    torch._functorch.config.patch(enable_remote_autograd_cache=False)
                )

            compiled_graph = compile_fx(
                graph,
                example_inputs,
                inner_compile=hijacked_compile_fx_inner,
                config_patches=current_config,
            )

        # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
        # compilation cache. So turn off the checks if we disable the
        # compilation cache.
        if not envs.VLLM_DISABLE_COMPILE_CACHE:
            if hash_str is None:
                raise RuntimeError(
                    "vLLM failed to compile the model. The most "
                    "likely reason for this is that a previous compilation "
                    "failed, leading to a corrupted compilation artifact. "
                    "We recommend trying to "
                    "remove ~/.cache/vllm/torch_compile_cache and try again "
                    "to see the real issue. "
                )
            assert file_path is not None, (
                "failed to get the file path of the compiled graph"
            )
        return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Callable:
        """
        Look up a previously compiled graph in the Inductor FX graph cache.

        `handle` must be a `(hash_str, file_path)` tuple of strings as
        returned by `compile`; only `hash_str` is used for the lookup.
        Returns a callable taking `*args` (Dynamo calling convention) that
        forwards them as a list to the Inductor-compiled graph and unwraps
        single outputs when `graph` does not return a tuple.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        with ExitStack() as exit_stack:
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            if torch.__version__.startswith("2.5"):
                inductor_compiled_graph = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, False
                )
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove "
                    "the cache directory and try again."
                )
            # Same reasoning as in `compile`: avoid lexicographic string
            # comparison of version numbers.
            elif is_torch_equal_or_newer("2.6"):
                from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

                constants = CompiledFxGraphConstantsWithGm(graph)
                inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                    hash_str, example_inputs, True, None, constants
                )
                assert inductor_compiled_graph is not None, (
                    "Inductor cache lookup failed. Please remove "
                    "the cache directory and try again."
                )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different level of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        if is_torch_equal_or_newer("2.6"):
            import torch._dynamo.utils

            return torch._dynamo.utils.get_metrics_context()
        else:
            return contextlib.nullcontext()

590

591
592
593
594
def set_inductor_config(config, runtime_shape):
    """
    Add shape-specific tuning options to an Inductor config dict in place.

    When `runtime_shape` is an int, `config["max_autotune"]` and
    `config["coordinate_descent_tuning"]` are set from the corresponding
    `VLLM_ENABLE_INDUCTOR_*` environment settings. For any other
    `runtime_shape` (e.g. None) `config` is left untouched.
    Returns None.
    """
    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
        config["coordinate_descent_tuning"] = (
            envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
        )
599
600


601
602
603
604
def set_functorch_config():
    torch._functorch.config.bundled_autograd_cache = False


605
606
607
608
609
610
class EagerAdaptor(CompilerInterface):
    """
    A pass-through "compiler": the graph is run as-is, with no caching.
    """

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Count the eager compile and return `graph` unchanged with a None
        handle. All other arguments are ignored.
        """
        compilation_counter.num_eager_compiles += 1
        # we don't need to compile the graph, just return the graph itself.
        # It does not support caching, return None for the handle.
        return graph, None