pass_manager.py 7.96 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
6

7
8
from torch import fx as fx

9
from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
12
from vllm.config import VllmConfig, set_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.platforms import current_platform
15
from vllm.utils.system_utils import set_env_var
16

17
from .ir.lowering_pass import VllmIRLoweringPass
18
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
19

20
if rocm_aiter_ops.is_enabled():
21
    from .fusion.rocm_aiter_fusion import (
22
        MLADualRMSNormFusionPass,
23
        RocmAiterRMSNormQuantFusionPass,
24
        RocmAiterSiluMulFp8GroupQuantFusionPass,
25
        RocmAiterTritonAddRMSNormPadFusionPass,
26
27
    )

28
if current_platform.is_cuda_alike():
29
    from .fusion.act_quant_fusion import ActivationQuantFusionPass
30
    from .fusion.attn_quant_fusion import AttnQuantFusionPass
31
    from .fusion.mla_attn_quant_fusion import MLAAttnQuantFusionPass
32
33
    from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
    from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
34
    from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
35
    from .fusion.sequence_parallelism import SequenceParallelismPass
36
    from .utility.scatter_split_replace import ScatterSplitReplacementPass
37
    from .utility.split_coalescing import SplitCoalescingPass
38

39
if current_platform.is_cuda():
40
41
    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
    from .fusion.collective_fusion import AsyncTPPass
42
    from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass
43

44
45
46
47
48
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
49
50
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
51
52
53

# Module-level logger; PostGradPassManager uses it to report skipped passes.
logger = init_logger(__name__)

54
55
# Typing helpers for `with_pattern_match_debug`: P captures the wrapped
# callable's parameters and R its return type, so the decorated function keeps
# the same signature for type checkers.
P = ParamSpec("P")
R = TypeVar("R")
56

57
58

def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorator that enables Inductor pattern-match debugging only while
    the wrapped function is running.

    When envs.VLLM_PATTERN_MATCH_DEBUG is not None, the call to `fn` is
    executed inside the `set_env_var` context manager with
    TORCHINDUCTOR_PATTERN_MATCH_DEBUG set to that value. Otherwise `fn`
    is called directly.
    Scoping the variable to the call avoids logging builtin Inductor
    pattern matching.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            # Debugging not requested: plain pass-through.
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper


76
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. self.passes, in insertion order (populated by configure() from the
       pass config and by add()); passes that are not applicable for the
       current compile range are skipped
    2. post_cleanup
    3. ir_lowering
    4. post_cleanup (again, to clean up after lowering)
    5. fix_functionalization
    This way, all passes operate on a functionalized graph.

    configure() must be called before the manager is invoked or uuid() is
    requested, because it creates pass_config, ir_lowering, post_cleanup
    and fix_functionalization.
    """

    def __init__(self) -> None:
        """Create a manager with an empty pass list."""
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run all configured passes on `graph` in the order documented on
        the class, then log the pattern-matcher summary.

        VllmInductorPass.dump_prefix is used as a running index while the
        passes execute: it is incremented after every pass that actually
        runs and reset to None afterwards.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index

        # dump_prefix is class-level (shared) state, so it is reset in a
        # `finally` to avoid leaking a stale index if any pass raises.
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if pass_.is_applicable_for_range(compile_range):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )

            # perform the first post-cleanup before IR lowering to clean up
            # fusion artifacts and make sure no dead IR ops are lowered.
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # lowering before cleanup so DCE can clean up lowered ops.
            # DCE handles mutating ops correctly as well.
            self.ir_lowering(graph)
            VllmInductorPass.dump_prefix += 1

            # clean up after lowering again
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            VllmInductorPass.dump_prefix = None  # Cleanup index

        VllmPatternMatcherPass.log_match_summary()

    def configure(self, config: VllmConfig) -> None:
        """
        Build the pass pipeline from `config.compilation_config.pass_config`.

        Each enabled flag appends its pass(es) to self.passes, in the order
        written below. The trailing ir_lowering, post_cleanup and
        fix_functionalization passes are always created.

        NOTE(review): most pass classes used here are imported conditionally
        at module level (current_platform.is_cuda_alike(),
        current_platform.is_cuda(), rocm_aiter_ops.is_enabled()). Only the
        ROCm AITER passes are re-guarded here; enabling one of the other
        flags on a platform where its import was skipped would raise
        NameError. Presumably config validation prevents this -- confirm.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.eliminate_noops:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sp:
                self.passes += [SequenceParallelismPass(config)]
                if self.pass_config.fuse_gemm_comms:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.fuse_allreduce_rms:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.fuse_minimax_qk_norm:
                self.passes += [MiniMaxQKNormPass(config)]

            if self.pass_config.fuse_norm_quant:
                self.passes += [RMSNormQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [
                        RocmAiterRMSNormQuantFusionPass(config),
                    ]
            if self.pass_config.fuse_act_quant:
                self.passes += [ActivationQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

            if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]

            if self.pass_config.fuse_mla_dual_rms_norm and rocm_aiter_ops.is_enabled():
                self.passes += [MLADualRMSNormFusionPass(config)]

            if self.pass_config.fuse_rope_kvcache:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [ScatterSplitReplacementPass(config)]
                self.passes += [RopeKVCacheFusionPass(config)]

            if self.pass_config.fuse_attn_quant:
                self.passes += [AttnQuantFusionPass(config)]
                self.passes += [MLAAttnQuantFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [QKNormRoPEFusionPass(config)]

            # Passes that __call__ always runs after self.passes.
            self.ir_lowering = VllmIRLoweringPass(config)
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append a custom pass to the end of self.passes."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}

        # UUIDs are listed in execution order, mirroring __call__:
        # self.passes, post_cleanup, ir_lowering, post_cleanup,
        # fix_functionalization.
        pass_uuids = [pass_.uuid() for pass_ in self.passes]
        pass_uuids += [
            self.post_cleanup.uuid(),
            self.ir_lowering.uuid(),
            self.post_cleanup.uuid(),
            self.fix_functionalization.uuid(),
        ]

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = pass_uuids
        return InductorPass.hash_dict(state)