inductor_pass.py 4.26 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from __future__ import annotations

6
import functools
7
8
import hashlib
import inspect
9
import json
10
import types
11
from collections.abc import Callable, Generator
12
from contextlib import contextmanager
13
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
14
15

import torch
16
from torch import fx
17
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
18

19
from vllm.utils.torch_utils import is_torch_equal_or_newer
20

21
22
23
if TYPE_CHECKING:
    from vllm.config.utils import Range

24
if is_torch_equal_or_newer("2.6"):
25
26
27
    from torch._inductor.custom_graph_pass import CustomGraphPass
else:
    # CustomGraphPass is not present in 2.5 or lower, import our version
28
    from .torch25_custom_graph_pass import (
29
30
        Torch25CustomGraphPass as CustomGraphPass,
    )
31

32
33
34
# Re-export CustomGraphPass for external usage
__all__ = ["CustomGraphPass"]

35
# Module-level slot holding the active PassContext. pass_context() assigns it
# for the duration of its `with` body and restores the previous value on exit;
# get_pass_context() reads it. None means no context is currently installed.
_pass_context: "PassContext | None" = None
36
37
# Generic placeholders for decorator typing in this module: P stands for the
# wrapped callable's parameters and R for its return type (see
# enable_fake_mode, which is annotated Callable[P, R] -> Callable[P, R]).
P = ParamSpec("P")
R = TypeVar("R")
38
39
40


class PassContext:
    """Container for the compile range that is active while passes run."""

    def __init__(self, compile_range: Range):
        # Stored as given; no copy or validation is performed.
        self.compile_range: Range = compile_range
43
44
45
46
47
48
49
50
51


def get_pass_context() -> PassContext:
    """Return the pass context that is currently active.

    Raises:
        AssertionError: if no context is active, i.e. the module-level
            ``_pass_context`` is ``None``.
    """
    context = _pass_context
    if context is None:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type is
        # unchanged for existing callers.
        raise AssertionError(
            "No pass context is active; call this inside pass_context()."
        )
    return context


@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
    """Install a ``PassContext`` for ``compile_range`` around the ``with`` body.

    The context that was active on entry (possibly ``None``) is remembered
    and put back on exit, also when the body raises, so uses may be nested.
    """
    global _pass_context
    saved = _pass_context
    _pass_context = PassContext(compile_range)
    try:
        yield
    finally:
        # Always restore whatever was installed before this block started.
        _pass_context = saved

64

65
class InductorPass(CustomGraphPass):  # type: ignore[misc]
66
    """
67
68
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
69
    """
70

71
    def uuid(self) -> str:
72
73
74
75
76
77
78
79
80
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
81
    def hash_source(*srcs: str | Any) -> str:
82
83
84
85
86
87
88
89
90
91
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
        Objects and functions have their source inspected.
        :return:
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
92
            elif isinstance(src, (types.FunctionType, type)):
93
94
                src_str = inspect.getsource(src)
            else:
95
                # object instance
96
97
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
98
99
100
        return hasher.hexdigest()

    @staticmethod
101
    def hash_dict(dict_: dict[Any, Any]) -> str:
102
103
104
105
106
107
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()
108

109
    def is_applicable_for_range(self, compile_range: Range) -> bool:
110
111
        return True

112
113
114
115
116
117
118

class CallableInductorPass(InductorPass):
    """
    Adapter that wraps a plain callable as a pass and supplies its UUID.

    The UUID is the ``uuid`` argument when one is given; otherwise it is
    obtained by passing the callable to ``hash_source``.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
    ) -> None:
        self.callable = callable
        # Only hash the callable when the caller did not provide a uuid.
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Invoke the wrapped callable on ``graph``; its result is discarded."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the uuid fixed at construction time."""
        return self._uuid
130
131


132
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
133
134
135
136
137
138
    """
    Applies a FakeTensorMode context. This is useful when you don't want to
    create or run things with real tensors.
    """

    @functools.wraps(fn)
139
    def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
140
        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
141
142
143
144
145
            result = fn(*args, **kwargs)

        return result

    return fn_new