inductor_pass.py 3.89 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
6
import hashlib
import inspect
7
import json
8
import types
9
from collections.abc import Callable
10
from contextlib import contextmanager
11
from typing import Any
12
13

import torch
14
from torch import fx
15
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
16

17
18
19
from vllm.utils import is_torch_equal_or_newer

if is_torch_equal_or_newer("2.6"):
    # torch >= 2.6 ships CustomGraphPass; use the upstream definition.
    from torch._inductor.custom_graph_pass import CustomGraphPass
else:
    # CustomGraphPass is not present in torch 2.5 or lower, so fall back to
    # the vendored copy under the same name.
    from .torch25_custom_graph_pass import (
        Torch25CustomGraphPass as CustomGraphPass,
    )
26

27
28
29
30
# The currently active PassContext, or None when no pass_context() block is
# open. Installed and restored by pass_context(); read by get_pass_context().
_pass_context: "PassContext | None" = None


class PassContext:
31
    def __init__(self, runtime_shape: int | None):
32
33
34
35
36
37
38
39
40
41
        self.runtime_shape = runtime_shape


def get_pass_context() -> PassContext:
    """Get the current pass context.

    Returns the context installed by the innermost active ``pass_context()``
    block.

    Raises:
        AssertionError: if called while no ``pass_context()`` block is active.
    """
    if _pass_context is None:
        # Explicit check rather than ``assert`` so the guard is not stripped
        # under ``python -O``; AssertionError is kept so existing callers
        # see the same exception type as before.
        raise AssertionError(
            "get_pass_context() called outside of a pass_context() block"
        )
    return _pass_context


@contextmanager
def pass_context(runtime_shape: int | None):
    """Install a PassContext for ``runtime_shape`` for the duration of the block.

    Inside the block, ``get_pass_context().runtime_shape`` returns
    ``runtime_shape``. The previously active context (possibly None) is
    restored on exit, including when the body raises, so nested
    ``pass_context()`` blocks behave like a stack.

    Args:
        runtime_shape: the runtime shape to expose to passes; may be None.
    """
    global _pass_context
    prev_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        # Always restore the outer context, even if the body raised.
        _pass_context = prev_context

54
55

class InductorPass(CustomGraphPass):
56
    """
57
58
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
59
    """
60

61
62
63
64
65
66
67
68
69
70
    def uuid(self) -> Any:
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
71
    def hash_source(*srcs: str | Any):
72
73
74
75
76
77
78
79
80
81
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
        Objects and functions have their source inspected.
        :return:
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
82
            elif isinstance(src, (types.FunctionType, type)):
83
84
                src_str = inspect.getsource(src)
            else:
85
                # object instance
86
87
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
88
89
90
        return hasher.hexdigest()

    @staticmethod
91
    def hash_dict(dict_: dict[Any, Any]):
92
93
94
95
96
97
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()
98

99
    def is_applicable_for_shape(self, shape: int | None):
100
101
        return True

102
103
104
105
106
107
108

class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the UUID.

    If no ``uuid`` is given, the UUID is the hash of the callable's source,
    so editing the callable changes the UUID.
    """

    def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None):
        # NOTE: ``callable`` shadows the builtin within this method; the name
        # is part of the public signature (callers may pass it by keyword).
        self.callable = callable
        self._uuid = self.hash_source(callable) if uuid is None else uuid

    def __call__(self, graph: torch.fx.Graph):
        """Run the wrapped callable on ``graph``."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the explicit UUID if one was given, else the source hash."""
        return self._uuid
118
119
120
121
122
123
124
125
126
127


def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Applies a FakeTensorMode context. This is useful when you don't want to
    create or run things with real tensors.
    """

    @functools.wraps(fn)
    def fn_new(*args, **kwargs) -> Any:
128
        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
129
130
131
132
133
            result = fn(*args, **kwargs)

        return result

    return fn_new