# decorators.py (19.9 KB) -- header residue ("Newer" / "Older") left over from a web blame view.
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import hashlib
6
import inspect
7
8
import os
import sys
9
10
from collections.abc import Callable
from typing import TypeVar, overload
11
from unittest.mock import patch
12
13

import torch
14
import torch.nn as nn
15
from packaging import version
16
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
17

18
import vllm.envs as envs
19
from vllm.compilation.counter import compilation_counter
20
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
21
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
22
from vllm.logger import init_logger
23
from vllm.sequence import IntermediateTensors
24
25
from vllm.utils import supports_dynamo
from vllm.utils.import_utils import resolve_obj_by_qualname
26

27
28
from .monitor import start_monitoring_torch_compile

29
# Module-level logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

# Name of the class attribute that `ignore_torch_compile` sets to True and
# `_support_torch_compile` resets to False; read back by
# `_should_ignore_torch_compile`.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type variable for the decorated class itself (a subclass of `nn.Module`),
# so the decorators in this module return the same class type they receive.
_T = TypeVar("_T", bound=type[nn.Module])


36
37
38
39
40
41
42
def ignore_torch_compile(cls: _T) -> _T:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Intended for a class that inherits from a parent decorated with
    `support_torch_compile` but whose own forward should not be compiled.
    Only the forward of the class the decorator is applied to is affected:

    - If a parent has `ignore_torch_compile` but a child is decorated with
      `support_torch_compile`, the child is still compiled.
    - If the class contains submodules that are themselves decorated with
      `support_torch_compile`, those submodules are still compiled.

    The decorator works by setting the `IGNORE_COMPILE_KEY` attribute on the
    class to True and returning the same class object.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls


def _should_ignore_torch_compile(cls) -> bool:
    """
    Report whether `cls` has been opted out of torch.compile.

    Returns the value stored under `IGNORE_COMPILE_KEY` on the class, or
    False when the attribute has never been set.
    """
    ignore_flag = getattr(cls, IGNORE_COMPILE_KEY, False)
    return ignore_flag


63
64
65
# Typing-only stubs describing the supported call shapes of
# `support_torch_compile`; the runtime implementation follows them.


# Shape 1: `@support_torch_compile()` or `@support_torch_compile(enable_if=...)`
# -> returns a class decorator.
@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T]: ...


# Shape 2: `@support_torch_compile(dynamic_arg_dims=...)`, optionally combined
# with `enable_if` (the implementation accepts both keywords together)
# -> returns a class decorator.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T]: ...


# Shape 3: bare `@support_torch_compile` applied directly to a class
# -> returns the class.
@overload
def support_torch_compile(cls: _T) -> _T: ...
79

80
81

def support_torch_compile(
    cls: _T | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor | None): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor | None): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `torch.Tensor | None`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `IntermediateTensors | None`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    the behavior depends on the value of each argument:

    - if it is a tensor, the configured dimension(s) (which can be negative)
        will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of all the
        tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    Raises:
        TypeError: if the decorated class has no `forward` attribute.
        ValueError: if no dynamic dimensions are given or can be inferred, or
            if a key of `dynamic_arg_dims` is not a parameter of `forward`.
    """

    # Annotations that cause a forward() parameter to get dim 0 marked dynamic
    # when `dynamic_arg_dims` is not given. Built once instead of per parameter.
    inferable_annotations = (
        torch.Tensor,
        torch.Tensor | None,
        IntermediateTensors,
        IntermediateTensors | None,
    )

    def cls_decorator_helper(cls: _T) -> _T:
        # helper that resolves `dynamic_arg_dims` and forwards it to
        # `_support_torch_compile`, keeping that function's body less nested
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)

        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {
                name: 0
                for name, param in sig.parameters.items()
                if param.annotation in inferable_annotations
            }
            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(inferred_dynamic_arg_dims),
            )

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        # Every configured name must be an actual parameter of forward().
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def _model_hash_key(fn) -> str:
    """
    Build a hex SHA-256 key identifying `fn` for cache lookups.

    The digest is fed, in order, the vLLM version string, the function's
    qualified name, and the first line number of its code object.
    """
    import vllm

    digest = hashlib.sha256()
    for component in (
        vllm.__version__,
        fn.__qualname__,
        str(fn.__code__.co_firstlineno),
    ):
        digest.update(component.encode())
    return digest.hexdigest()


def _verify_source_unchanged(source_info, vllm_config) -> None:
    """
    Check that the source files recorded in `source_info` still match disk.

    For every entry in `source_info.inlined_sources`, the file of the
    already-imported module is added to the compilation config's
    `traced_files` and paired with the recorded content. A hash of the
    recorded contents is then compared against a hash computed from the
    same set of file paths.

    Raises:
        RuntimeError: if the two hashes differ.
        KeyError: if a recorded module is not present in `sys.modules`.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    recorded_contents = {}
    for inlined in source_info.inlined_sources:
        path = inspect.getfile(sys.modules[inlined.module])
        vllm_config.compilation_config.traced_files.add(path)
        recorded_contents[path] = inlined.content

    checksum_recorded = _compute_code_hash_with_content(recorded_contents)
    checksum_current = _compute_code_hash(set(recorded_contents))
    if checksum_recorded == checksum_current:
        return
    raise RuntimeError(
        "Source code has changed since the last compilation. Recompiling the model."
    )


212
213
def _mark_dynamic_dims(arg, dims: int | list[int], arg_name: str) -> None:
    # Mark the configured dimension(s) of one bound forward() argument as
    # dynamic for Dynamo; `arg` must be a tensor or an IntermediateTensors.
    dim_list = [dims] if isinstance(dims, int) else dims

    def resolve(tensor: torch.Tensor) -> list[int]:
        # Negative indices are resolved against *this* tensor's rank, so
        # tensors of different ranks each get the correct dimension.
        return [tensor.ndim + dim if dim < 0 else dim for dim in dim_list]

    if isinstance(arg, torch.Tensor):
        torch._dynamo.mark_dynamic(arg, resolve(arg))
    elif isinstance(arg, IntermediateTensors):
        for tensor in arg.tensors.values():
            torch._dynamo.mark_dynamic(tensor, resolve(tensor))
    else:
        raise ValueError(
            "Unsupported dynamic dimensions"
            f" {dim_list} for argument {arg_name} with type {type(arg)}."
        )


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> _T:
    """
    Add support for compiling the forward method of `cls`, in place.

    The class gets `TorchCompileWrapperWithCustomDispatcher` appended to its
    bases, its `IGNORE_COMPILE_KEY` flag reset to False, and its `__init__`
    and `__call__` replaced by the wrappers defined below.

    Args:
        cls: the class to patch; it is modified and returned.
        dynamic_arg_dims: maps forward() argument names to the dimension(s)
            to mark as dynamic before the first compilation.
        enable_if: optional predicate on the `VllmConfig`; when it returns
            False the instance is not compiled.

    Returns:
        The same class object. A class that already lists the wrapper among
        its direct bases is returned unchanged.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)

    old_init = cls.__init__

    # An explicitly decorated class is compiled even if a parent was marked
    # with `ignore_torch_compile`.
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
        # Run the original constructor first, then decide whether this
        # instance should be compiled at all.
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_mode=vllm_config.compilation_config.mode
        )

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # Entry point replacing the class's `__call__`. Attributes such as
        # `compiled_codes`, `use_custom_dispatcher`, `original_code_object`,
        # `compiled_callable`, `aot_compile` and `dispatch_to_code` are not
        # defined in this file; presumably they come from
        # TorchCompileWrapperWithCustomDispatcher.

        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # An AOT-compiled function from an earlier call is reused directly.
        if getattr(self, "aot_compiled_fn", None) is not None:
            return self.aot_compiled_fn(self, *args, **kwargs)

        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            # When using torch.compile in AOT mode, we store the cache artifacts
            # under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
            # contains all of the factors except for the source files being
            # traced through, because we don't actually know which source files
            # to check at this point (before dynamo runs).
            # On loading we will actually look at the source files being traced
            # through. If any source file has changed (compared with the
            # serialized backend artifacts), then we need to generate a new AOT
            # compile artifact from scratch.
            from .caching import compilation_config_hash_factors

            factors: list[str] = compilation_config_hash_factors(self.vllm_config)

            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()

            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_aot_compile",
                hash_key,
            )

            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_rank
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            try:
                with (
                    set_current_vllm_config(self.vllm_config),
                    open(aot_compilation_path, "rb") as f,
                ):
                    start_monitoring_torch_compile(self.vllm_config)
                    loaded_fn = torch.compiler.load_compiled_function(f)
                _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
                self.aot_compiled_fn = loaded_fn
            except Exception as e:
                # Best effort: a missing artifact is silent, an existing but
                # unusable one is logged; both fall through to a fresh
                # compilation unless loading is mandatory.
                if os.path.exists(aot_compilation_path):
                    logger.warning(
                        "Cannot load aot compilation from path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
                if envs.VLLM_FORCE_AOT_LOAD:
                    raise
            if getattr(self, "aot_compiled_fn", None) is not None:
                logger.info(
                    "Directly load AOT compilation from path %s", aot_compilation_path
                )
                return self.aot_compiled_fn(self, *args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                # `None` arguments are skipped.
                if arg is not None:
                    _mark_dynamic_dims(arg, dims, k)
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s", self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename
            )

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call_
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call_

            def patched_inline_call(self_):
                code = self_.f_code
                self.vllm_config.compilation_config.traced_files.add(code.co_filename)
                return inline_call(self_)

            # Disable the C++ compilation of symbolic shape guards. C++-fication
            # of symbolic shape guards can improve guard overhead. But, since
            # vllm skip guards anyways, setting this flag to False can improve
            # compile time.
            dynamo_config_patches = {}
            try:
                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
                dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
            except AttributeError:
                # Note: this config is not available in torch 2.6, we can skip
                # if the config doesn't exist
                logger.debug("enable_cpp_symbolic_shape_guards config not available")

            with (
                patch.object(
                    InliningInstructionTranslator, "inline_call_", patched_inline_call
                ),
                torch._dynamo.config.patch(**dynamo_config_patches),
                maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                _torch27_patch_tensor_subclasses(),
            ):
                if envs.VLLM_USE_AOT_COMPILE:
                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                    output = self.aot_compiled_fn(self, *args, **kwargs)
                    # Both paths were computed in the AOT branch above.
                    assert aot_compilation_path is not None
                    assert cache_dir is not None
                    os.makedirs(cache_dir, exist_ok=True)
                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
                else:
                    output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            return self.forward(*args, **kwargs)

    cls.__call__ = __call__
    return cls
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is only installed when the compilation config reports piecewise
    cudagraphs and `use_inductor_graph_partition` is set; otherwise this
    context manager does nothing. When installed, the hook is always reset
    to `None` on exit, including when the wrapped code raises.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Decide once, so installing and removing the hook cannot disagree.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
        # Wrap one Inductor partition `f` in the platform's static graph
        # wrapper; per-partition options depend on the partition's position.
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                # debug logging only for the first partition
                debug_log_enable=partition_id == 0,
                # gc_disable for every partition except the first
                gc_disable=partition_id != 0,
                # weak_ref_output only for the last partition
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        # Reset the global hook even if compilation failed, so it does not
        # stay installed for later, unrelated compilations.
        torch._inductor.utils.set_customized_partition_wrappers(None)
475
476
477
478
479
480
481
482
483
484


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
    """
    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
    when using torch 2.7.0. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    While active, the patch registers the vLLM parameter classes in Dynamo's
    `traceable_tensor_subclasses` config and replaces
    `torch._dynamo.variables.torch.can_dispatch_torch_function` with a
    function that always returns False.

    NOTE(review): as written, the version guard below yields WITHOUT patching
    when torch is 2.7.x and applies the patch on every other version, which
    looks inverted relative to the description above. Behaviour is left
    unchanged here -- confirm the intended condition.
    """
    from vllm.model_executor.parameter import (
        BasevLLMParameter,
        ModelWeightParameter,
        RowvLLMParameter,
        _ColumnvLLMParameter,
    )

    def return_false(*args, **kwargs):
        # Stand-in for `can_dispatch_torch_function`: always answers False.
        return False

    if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
        # torch 2.7.x: run the wrapped code without any patching.
        yield
        return

    with (
        torch._dynamo.config.patch(
            "traceable_tensor_subclasses",
            [
                BasevLLMParameter,
                ModelWeightParameter,
                _ColumnvLLMParameter,
                RowvLLMParameter,
            ],
        ),
        patch(
            "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
        ),
    ):
        yield