pass_manager.py 7.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
6

7
8
from torch import fx as fx

9
from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
12
from vllm.config import VllmConfig, set_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.platforms import current_platform
15
from vllm.utils.system_utils import set_env_var
16

17
from .ir.lowering_pass import VllmIRLoweringPass
18
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
19

20
if rocm_aiter_ops.is_enabled():
21
    from .fusion.rocm_aiter_fusion import (
22
        RocmAiterRMSNormQuantFusionPass,
23
        RocmAiterSiluMulFp8GroupQuantFusionPass,
24
        RocmAiterTritonAddRMSNormPadFusionPass,
25
26
    )

27
if current_platform.is_cuda_alike():
28
    from .fusion.act_quant_fusion import ActivationQuantFusionPass
29
    from .fusion.attn_quant_fusion import AttnQuantFusionPass
30
    from .fusion.mla_attn_quant_fusion import MLAAttnQuantFusionPass
31
32
    from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
    from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
33
    from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
34
    from .fusion.sequence_parallelism import SequenceParallelismPass
35
    from .utility.scatter_split_replace import ScatterSplitReplacementPass
36
    from .utility.split_coalescing import SplitCoalescingPass
37

38
if current_platform.is_cuda():
39
40
    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
    from .fusion.collective_fusion import AsyncTPPass
41

42
43
44
45
46
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
47
48
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
49
50
51

# Module-level logger for this file, obtained from vLLM's init_logger.
logger = init_logger(__name__)

52
53
# Typing helpers for with_pattern_match_debug: P captures the wrapped
# callable's parameters and R its return type, so the decorated function
# keeps its original signature for type checkers.
P = ParamSpec("P")
R = TypeVar("R")
54

55
56

def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorator that enables Inductor pattern-match debugging only while
    the wrapped function is running.

    When envs.VLLM_PATTERN_MATCH_DEBUG is not None, the call is executed
    inside set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", <that value>).
    Otherwise the function is called directly with no environment change.
    Scoping the setting to the call is meant to keep builtin Inductor
    pattern matching out of the debug log.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # The env value is read on every call, not once at decoration time.
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            return fn(*args, **kwargs)

        # optionally check rank here
        # NOTE(review): set_env_var presumably restores the previous value
        # on exit - confirm in vllm.utils.system_utils.
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper


74
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    As a CustomGraphPass it provides uuid(), which is used for Inductor
    compilation caching (see .inductor_pass.CustomGraphPass).

    When the manager is called on a graph, the stages run in this order:
    1. the registered passes, in registration order (those created by
       configure() from the pass config, plus any added through add());
       a pass is skipped if it is not applicable for the compile range
    2. post_cleanup
    3. ir_lowering
    4. post_cleanup (again)
    5. fix_functionalization (always last)

    configure() must be called before the manager is invoked or uuid() is
    computed, because it creates pass_config and the fixed trailing passes.
    """

    def __init__(self) -> None:
        # Ordered list of passes; filled by configure() and add().
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run every applicable pass on `graph`, then the fixed trailing stages.

        VllmInductorPass.dump_prefix is set to 0 at the start, incremented
        after each executed stage except the final fix_functionalization,
        and set back to None at the end.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index

        compile_range = get_pass_context().compile_range
        for pass_ in self.passes:
            if not pass_.is_applicable_for_range(compile_range):
                logger.debug("Skipping %s with compile range %s", pass_, compile_range)
                continue
            pass_(graph)
            VllmInductorPass.dump_prefix += 1

        # Fixed trailing stages, in order:
        # - first post-cleanup before IR lowering to clean up fusion
        #   artifacts and make sure no dead IR ops are lowered;
        # - lowering before cleanup so DCE can clean up lowered ops
        #   (DCE handles mutating ops correctly as well);
        # - clean up after lowering again.
        trailing_stages = (self.post_cleanup, self.ir_lowering, self.post_cleanup)
        for stage in trailing_stages:
            stage(graph)
            VllmInductorPass.dump_prefix += 1

        # always run fix_functionalization last
        self.fix_functionalization(graph)
        VllmInductorPass.dump_prefix = None  # Cleanup index

        VllmPatternMatcherPass.log_match_summary()

    def configure(self, config: VllmConfig) -> None:
        """
        Build the pass list from `config.compilation_config.pass_config`.

        Passes are appended to self.passes in the order of the checks
        below, and the fixed trailing passes (ir_lowering, post_cleanup,
        fix_functionalization) are created unconditionally at the end.
        """
        pass_config = config.compilation_config.pass_config
        self.pass_config = pass_config
        registered = self.passes

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if pass_config.eliminate_noops:
                registered.append(NoOpEliminationPass(config))

            if pass_config.enable_sp:
                registered.append(SequenceParallelismPass(config))
                if pass_config.fuse_gemm_comms:
                    registered.append(AsyncTPPass(config))

            if pass_config.fuse_allreduce_rms:
                registered.append(AllReduceFusionPass(config))

            if pass_config.fuse_norm_quant:
                registered.append(RMSNormQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    registered.append(RocmAiterRMSNormQuantFusionPass(config))

            if pass_config.fuse_act_quant:
                registered.append(ActivationQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    registered.append(RocmAiterSiluMulFp8GroupQuantFusionPass(config))

            if pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                registered.append(RocmAiterTritonAddRMSNormPadFusionPass(config))

            if pass_config.fuse_rope_kvcache:
                registered.append(SplitCoalescingPass(config))
                registered.append(ScatterSplitReplacementPass(config))
                registered.append(RopeKVCacheFusionPass(config))

            if pass_config.fuse_attn_quant:
                registered.append(AttnQuantFusionPass(config))
                registered.append(MLAAttnQuantFusionPass(config))

            if pass_config.enable_qk_norm_rope_fusion:
                registered.append(SplitCoalescingPass(config))
                registered.append(QKNormRoPEFusionPass(config))

            # Trailing passes that __call__ always runs after self.passes.
            self.ir_lowering = VllmIRLoweringPass(config)
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append a custom pass to the end of the current pass list."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        The pass UUIDs are listed in the order __call__ runs them: the
        registered passes, then post_cleanup, ir_lowering, post_cleanup and
        fix_functionalization.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}

        pass_uuids = [pass_.uuid() for pass_ in self.passes]
        pass_uuids += [
            self.post_cleanup.uuid(),
            self.ir_lowering.uuid(),
            self.post_cleanup.uuid(),
            self.fix_functionalization.uuid(),
        ]

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = pass_uuids
        return InductorPass.hash_dict(state)