# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

from torch import fx as fx

from vllm import envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var

from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass

# Fusion passes are only importable/usable on CUDA-like platforms.
if current_platform.is_cuda_alike():
    from .activation_quant_fusion import ActivationQuantFusionPass
    from .fusion import RMSNormQuantFusionPass
    from .fusion_attn import AttnFusionPass

# Collective (all-reduce / async TP) passes are CUDA-only.
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass

logger = init_logger(__name__)


def with_pattern_match_debug(fn):
    """
    Decorate ``fn`` so Inductor pattern-match debugging is active while it runs.

    On every call the wrapper reads ``envs.VLLM_PATTERN_MATCH_DEBUG``. When it
    is set (not ``None``), ``TORCHINDUCTOR_PATTERN_MATCH_DEBUG`` is set to the
    same value for the duration of the call via ``set_env_var``. Otherwise
    ``fn`` is called unchanged. Scoping the variable this way avoids logging
    Inductor's builtin pattern matching outside vLLM's own passes.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        debug_setting = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_setting is None:
            # Debugging not requested: plain passthrough.
            return fn(*args, **kwargs)
        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_setting):
            return fn(*args, **kwargs)

    return wrapper


class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. ``self.passes`` in insertion order. These are the default passes
       appended by ``configure`` (depending on the pass config:
       NoOpEliminationPass, SequenceParallelismPass, AsyncTPPass,
       AllReduceFusionPass, RMSNormQuantFusionPass,
       ActivationQuantFusionPass, AttnFusionPass) and any passes registered
       via ``add``. Presumably this includes a user-supplied
       ``post_grad_custom_post_pass``; that registration happens outside
       this class.
    2. PostCleanupPass
    3. FixFunctionalizationPass (always last)
    This way, all passes operate on a functionalized graph.

    ``configure`` must be called before the manager is run or hashed. It
    stores the pass config and creates the post-cleanup and
    fix-functionalization passes.
    """

    def __init__(self):
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph):
        """
        Run every applicable pass on ``graph`` in place.

        Passes whose ``is_applicable`` rejects the current runtime shape are
        skipped, and the skip is logged at debug level.
        ``VllmInductorPass.dump_prefix`` is the dump index. It is incremented
        after each pass that runs and reset to ``None`` on exit, including
        when a pass raises, so a failed compile does not leak a stale index.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            shape = get_pass_context().runtime_shape
            for pass_ in self.passes:
                if pass_.is_applicable(shape):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug("Skipping %s with shape %s", pass_, shape)

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            VllmInductorPass.dump_prefix = None  # Cleanup index

    def configure(self, config: VllmConfig):
        """
        Build the pass pipeline from ``config``.

        Appends the enabled default passes to ``self.passes`` and creates the
        post-cleanup and fix-functionalization passes. It also records which
        operators vLLM asks Inductor to split (see the HACK note below).
        Calling this more than once appends the default passes again.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.enable_noop:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sequence_parallelism:
                self.passes += [SequenceParallelismPass(config)]
                # Async TP is only added on top of sequence parallelism.
                if self.pass_config.enable_async_tp:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.enable_fi_allreduce_fusion:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.enable_fusion:
                self.passes += [RMSNormQuantFusionPass(config)]
                self.passes += [ActivationQuantFusionPass(config)]

            if self.pass_config.enable_attn_fusion:
                self.passes += [AttnFusionPass(config)]

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

        # [HACK: Bug with Inductor graph partition and torch.compile cache]
        # In PyTorch 2.9, torch.compile has a bug where the graph
        # partition is not taken into account during caching.
        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
        # Inductor graph partition, and VLLM_COMPILE implies there
        # is a PostGradPassManager, we put the list of operators to graph
        # partition into the PostGradPassManager's uuid (which
        # then gets incorporated into Inductor's FX graph cache key).
        # Remove this hack whenever torch.compile fixes it.

        # This is the list of operators that vLLM asks Inductor to split.
        self.inductor_splitting_ops = []
        if (
            config.compilation_config.use_inductor_graph_partition
            and config.compilation_config.splitting_ops is not None
        ):
            # Sort them so we're not dependent on the ordering.
            self.inductor_splitting_ops = sorted(
                config.compilation_config.splitting_ops
            )

    def add(self, pass_: InductorPass):
        """Append ``pass_`` so it runs after all previously added passes."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        Every pass that ``__call__`` runs contributes its uuid, in execution
        order. That includes post-cleanup and fix-functionalization, so that
        changing any of them invalidates the cache key.
        """
        state = {
            "pass_config": self.pass_config.uuid(),
            "passes": [pass_.uuid() for pass_ in self.passes],
            "inductor_splitting_ops": [],
        }
        # Fixed trailing passes, in the same order __call__ runs them.
        state["passes"].append(self.post_cleanup.uuid())
        state["passes"].append(self.fix_functionalization.uuid())

        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
        state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)

        return InductorPass.hash_dict(state)