# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

6
import functools
7
8
import hashlib
import inspect
9
import json
10
import types
11
from collections.abc import Callable, Generator
12
from contextlib import contextmanager
13
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
14
15

import torch
16
from torch import fx
17
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
18

19
20
21
if TYPE_CHECKING:
    from vllm.config.utils import Range

22
from torch._inductor.custom_graph_pass import CustomGraphPass
23

24
_pass_context = None
25
26
P = ParamSpec("P")
R = TypeVar("R")
27
28
29


class PassContext:
    """Container for the state handed to passes through get_pass_context()."""

    # Compile range for which the current pass invocation was set up.
    compile_range: Range

    def __init__(self, compile_range: Range):
        self.compile_range = compile_range


def get_pass_context() -> PassContext:
    """Return the PassContext installed by the innermost active ``pass_context``.

    :raises AssertionError: if called outside of any ``pass_context`` block.
    """
    # Explicit check instead of a bare ``assert`` so the guard is not stripped
    # under ``python -O`` (which would silently return None). AssertionError
    # is kept so existing callers see the same exception type.
    if _pass_context is None:
        raise AssertionError(
            "get_pass_context() called outside of a pass_context() block"
        )
    return _pass_context


@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
    """Install a PassContext for ``compile_range`` for the duration of the block.

    The previously active context (possibly ``None``) is saved on entry and
    put back on exit, even when the body raises, so nested uses unwind
    correctly and ``get_pass_context`` always sees the innermost one.
    """
    global _pass_context
    saved, _pass_context = _pass_context, PassContext(compile_range)
    try:
        yield
    finally:
        _pass_context = saved

@functools.cache
def _hash_source_cached(*srcs: str | type | types.FunctionType) -> str:
    """SHA-256 hex digest over the UTF-8 text of ``srcs``, taken in order.

    Strings contribute their own text; classes and functions contribute the
    output of ``inspect.getsource``. Memoized on the argument tuple, so the
    arguments must be hashable.
    """
    # Hashing the concatenation is equivalent to feeding each chunk to
    # ``update`` one after another.
    encoded_parts = [
        (src if isinstance(src, str) else inspect.getsource(src)).encode("utf-8")
        for src in srcs
    ]
    return hashlib.sha256(b"".join(encoded_parts)).hexdigest()


63
class InductorPass(CustomGraphPass):  # type: ignore[misc]
64
    """
65
66
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
67
    """
68

69
    def uuid(self) -> str:
70
71
72
73
74
75
76
77
78
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
79
    def hash_source(*srcs: str | Any) -> str:
80
81
82
83
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
        Objects and functions have their source inspected.
84
85
        Results are cached by resolved types to avoid repeated
        inspect.getsource() calls.
86
87
        :return:
        """
88
89
90
91
92
93
        # Resolve instances to their class for a hashable cache key.
        cache_key = tuple(
            src if isinstance(src, (str, type, types.FunctionType)) else src.__class__
            for src in srcs
        )
        return _hash_source_cached(*cache_key)
94
95

    @staticmethod
96
    def hash_dict(dict_: dict[Any, Any]) -> str:
97
98
99
100
101
102
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()
103

104
    def is_applicable_for_range(self, compile_range: Range) -> bool:
105
106
        return True

107
108
109
110
111
112
113

class CallableInductorPass(InductorPass):
    """
    Adapts a plain ``fx.Graph -> None`` callable into an InductorPass.

    Unless an explicit ``uuid`` is supplied, the identifier is the hash of the
    callable's source, computed once when the wrapper is constructed.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
    ) -> None:
        self.callable = callable
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph) -> None:
        # Delegate to the wrapped function; its return value (annotated as
        # None) is discarded.
        self.callable(graph)

    def uuid(self) -> Any:
        return self._uuid


127
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
128
129
130
131
132
133
    """
    Applies a FakeTensorMode context. This is useful when you don't want to
    create or run things with real tensors.
    """

    @functools.wraps(fn)
134
    def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
135
        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
136
137
138
139
140
            result = fn(*args, **kwargs)

        return result

    return fn_new