pass_manager.py 2.99 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Any, Dict, List
4

5
import torch
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch import fx as fx

from vllm.config import CompilationConfig
from vllm.logger import init_logger

from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .reshapes import RedundantReshapesPass

logger = init_logger(__name__)


19
20
21
22
23
24
25
26
27
28
29
class PlaceHolder:
    """
    Empty stand-in base class.

    It is bound to `Parent` when the installed torch is older than 2.6, so
    that the pass manager always has a base class to inherit from.
    """


# Pick the base class for the pass manager: the empty PlaceHolder on
# torch < 2.6, Inductor's CustomGraphPass otherwise.
# torch.__version__ is a TorchVersion, so comparing it with a plain string
# follows version ordering rather than character-by-character ordering.
if torch.__version__ < "2.6":
    Parent = PlaceHolder  # type: ignore
else:
    # Import the submodule explicitly. Reaching it through attribute access
    # (`torch._inductor.custom_graph_pass`) only works when some other
    # import has already loaded that submodule as a side effect.
    from torch._inductor.custom_graph_pass import (  # type: ignore
        CustomGraphPass as Parent)


class PostGradPassManager(Parent):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It also supports pickling, which is used by the Inductor code cache.
    TODO(torch==2.6), use CustomGraphPass
    (torch._inductor.custom_graph_pass.CustomGraphPass)

    Passes are stored in `self.passes` and run in list order:
    1. default passes appended by `configure`
       (RedundantReshapesPass, then FusionPass, each only if enabled)
    2. passes appended through `add`
    3. fix_functionalization, which `__call__` always runs last
    This way, all passes operate on a functionalized graph.

    The relative order of 1. and 2. is simply the order in which
    `configure` and `add` were called.

    NOTE(review): an earlier version of this docstring listed
    "passes (constructor parameter)" and
    config["post_grad_custom_post_pass"] as separate steps. The constructor
    takes no passes; presumably such passes are attached via `add` - confirm
    against the caller.

    `configure` must be called before the manager is run or pickled, since
    it is what sets `pass_config` and `fix_functionalization`.
    """

    def __init__(self):
        # Ordered list of passes; filled in by `configure` and `add`.
        self.passes: List[InductorPass] = []

    def __call__(self, graph: fx.Graph):
        """Run every registered pass on `graph`, in registration order."""
        for registered in self.passes:
            registered(graph)

        # fix_functionalization is not stored in `self.passes`; it is
        # invoked here so that it is always the final pass.
        self.fix_functionalization(graph)

    def configure(self, pass_config: CompilationConfig.PassConfig):
        """
        Store `pass_config` and register the default passes it enables.

        Each call appends to `self.passes`; nothing is cleared first.
        """
        self.pass_config = pass_config

        if pass_config.enable_reshape:
            self.passes.append(RedundantReshapesPass(pass_config))

        if pass_config.enable_fusion:
            self.passes.append(FusionPass.instance(pass_config))

        self.fix_functionalization = FixFunctionalizationPass(pass_config)

    def add(self, pass_: InductorPass):
        """Append `pass_` (which must be an InductorPass) to the pass list."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """Return the same state dict that `__getstate__` produces."""
        return self.__getstate__()

    def __getstate__(self) -> Dict[str, List[Any]]:
        """
        Custom pickling for the pass manager, as some passes cannot be pickled.
        Pickling occurs because the pass manager is set as the value of
        `config["post_grad_custom_post_pass"]` in the Inductor config.
        The config is pickled to act as a key in the Inductor code cache.
        Any other passes in the config are pickled as well.

        The state holds uuids instead of the pass objects themselves:
        "pass_config" is `self.pass_config.uuid()`, and "passes" is the uuid
        of every registered pass followed by that of fix_functionalization.

        NOTE(review): the return annotation says every value is a list, but
        "pass_config" holds whatever `pass_config.uuid()` returns - confirm.

        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
        """
        config_id = self.pass_config.uuid()
        pass_ids = [registered.uuid() for registered in self.passes]
        pass_ids.append(self.fix_functionalization.uuid())
        return {"pass_config": config_id, "passes": pass_ids}

    def __setstate__(self, state):
        """
        Do not allow unpickling of the pass manager.
        If this is needed in the future, it should properly pickle the passes.

        Raises ValueError unconditionally.
        """
        raise ValueError("Cannot unpickle PostGradPassManager")