inductor_pass.py 2.56 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import hashlib
import inspect
5
import json
6
import types
7
from typing import Any, Callable, Dict, Optional, Union
8
9

import torch
10
from torch import fx
11

12
13
14
from vllm.utils import is_torch_equal_or_newer

# Pick the CustomGraphPass base class matching the installed torch version.
if is_torch_equal_or_newer("2.6"):
    from torch._inductor.custom_graph_pass import CustomGraphPass
else:
    # CustomGraphPass is not present in 2.5 or lower, import our version
    from .torch25_custom_graph_pass import (  # noqa: yapf
        Torch25CustomGraphPass as CustomGraphPass)
20

21
22

class InductorPass(CustomGraphPass):
23
    """
24
25
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
26
    """
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
    def uuid(self) -> Any:
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
    def hash_source(*srcs: Union[str, Any]):
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
        Objects and functions have their source inspected.
        :return:
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
            elif isinstance(src, types.FunctionType):
                src_str = inspect.getsource(src)
            else:
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
54
55
56
57
58
59
60
61
62
63
        return hasher.hexdigest()

    @staticmethod
    def hash_dict(dict_: Dict[Any, Any]):
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()
64
65
66
67
68
69
70
71
72
73
74
75


class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the UUID.
    """

    def __init__(self,
                 callable: Callable[[fx.Graph], None],
                 uuid: Optional[Any] = None):
        """
        :param callable: the graph transformation to run; it is invoked with
            the graph and its return value is ignored.
        :param uuid: identifier returned by ``uuid()``. When ``None``, the
            source of ``callable`` is hashed with ``hash_source`` instead.
        """
        # NOTE: ``callable`` shadows the builtin and ``uuid`` mirrors the
        # method name; both are kept because callers may pass them by keyword.
        self.callable = callable
        self._uuid = self.hash_source(callable) if uuid is None else uuid

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Run the wrapped callable on ``graph``."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the identifier fixed at construction time."""
        return self._uuid