pass_manager.py 6.92 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
6

7
8
from torch import fx as fx

9
from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
12
from vllm.config import VllmConfig, set_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.platforms import current_platform
15
from vllm.utils.system_utils import set_env_var
16

17
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
18

19
if rocm_aiter_ops.is_enabled():
20
    from .fusion.rocm_aiter_fusion import (
21
        RocmAiterRMSNormQuantFusionPass,
22
        RocmAiterSiluMulFp8GroupQuantFusionPass,
23
        RocmAiterTritonAddRMSNormPadFusionPass,
24
25
    )

26
if current_platform.is_cuda_alike():
27
    from .fusion.act_quant_fusion import ActivationQuantFusionPass
28
    from .fusion.attn_quant_fusion import AttnQuantFusionPass
29
30
    from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
    from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
31
    from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
32
    from .fusion.sequence_parallelism import SequenceParallelismPass
33
    from .utility.scatter_split_replace import ScatterSplitReplacementPass
34
    from .utility.split_coalescing import SplitCoalescingPass
35

36
if current_platform.is_cuda():
37
38
    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
    from .fusion.collective_fusion import AsyncTPPass
39

40
41
42
43
44
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
45
46
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
47
48
49

# Module-level logger, named after this module.
logger = init_logger(__name__)

50
51
# Type variables used by with_pattern_match_debug so the decorated callable
# keeps its parameter list (P) and return type (R) for type checkers.
P = ParamSpec("P")
R = TypeVar("R")
52

53
54

def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate ``fn`` so Inductor pattern-match debugging is active only
    while ``fn`` is executing.

    If ``envs.VLLM_PATTERN_MATCH_DEBUG`` is not None, its value is exported
    as ``TORCHINDUCTOR_PATTERN_MATCH_DEBUG`` around the call via
    ``set_env_var``; otherwise ``fn`` is invoked unchanged.
    Limiting the variable to the call keeps Inductor's own builtin pattern
    matching out of the debug log.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            # Debugging not requested: plain pass-through.
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper


72
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    When the manager is called on a graph, passes run in this order:
    1. every pass in ``self.passes``, in list order. ``configure()`` appends
       the passes enabled by the pass config; ``add()`` appends further
       passes (presumably how custom post passes such as
       config["post_grad_custom_post_pass"] get registered - confirm
       against the caller)
    2. post_cleanup
    3. fix_functionalization
    This way, every pass before fix_functionalization operates on a
    functionalized graph.
    """

    def __init__(self) -> None:
        # Ordered list of passes run by __call__ before the fixed
        # post_cleanup / fix_functionalization tail.
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run all applicable passes on ``graph`` in place.

        Passes whose ``is_applicable_for_range`` rejects the current compile
        range (taken from the pass context) are skipped and logged at debug
        level. ``VllmInductorPass.dump_prefix`` counts the passes that ran
        and is reset to None on exit.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if not pass_.is_applicable_for_range(compile_range):
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )
                    continue
                pass_(graph)
                VllmInductorPass.dump_prefix += 1

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index. Done in a finally so a failing pass does not
            # leave the class-level dump index set for later use.
            VllmInductorPass.dump_prefix = None

        VllmPatternMatcherPass.log_match_summary()

    def configure(self, config: VllmConfig) -> None:
        """
        Build the default pass pipeline from ``config``.

        Stores ``config.compilation_config.pass_config`` on the manager,
        appends one pass instance per enabled flag to ``self.passes`` (the
        order below is the execution order), and creates the trailing
        ``post_cleanup`` and ``fix_functionalization`` passes.

        NOTE(review): the fusion pass classes are imported at module level
        only under platform checks (ROCm AITER / CUDA-alike / CUDA), so
        enabling a flag on a platform where its class was not imported
        raises NameError here - presumably config validation prevents that.
        """
        self.pass_config = config.compilation_config.pass_config
        pass_config = self.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if pass_config.eliminate_noops:
                self.passes.append(NoOpEliminationPass(config))

            if pass_config.enable_sp:
                self.passes.append(SequenceParallelismPass(config))
                if pass_config.fuse_gemm_comms:
                    self.passes.append(AsyncTPPass(config))

            if pass_config.fuse_allreduce_rms:
                self.passes.append(AllReduceFusionPass(config))

            if pass_config.fuse_norm_quant:
                self.passes.append(RMSNormQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(RocmAiterRMSNormQuantFusionPass(config))

            if pass_config.fuse_act_quant:
                self.passes.append(ActivationQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(
                        RocmAiterSiluMulFp8GroupQuantFusionPass(config)
                    )

            if pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes.append(RocmAiterTritonAddRMSNormPadFusionPass(config))

            if pass_config.fuse_rope_kvcache:
                self.passes.append(SplitCoalescingPass(config))
                self.passes.append(ScatterSplitReplacementPass(config))
                self.passes.append(RopeKVCacheFusionPass(config))

            if pass_config.fuse_attn_quant:
                self.passes.append(AttnQuantFusionPass(config))

            if pass_config.enable_qk_norm_rope_fusion:
                self.passes.append(SplitCoalescingPass(config))
                self.passes.append(QKNormRoPEFusionPass(config))

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append ``pass_`` to the end of ``self.passes``.

        ``pass_`` must be an ``InductorPass`` instance (checked with an
        assert).
        """
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        The hashed state holds the pass-config hash, the current compile
        range, and the uuids of ``self.passes`` followed by
        ``fix_functionalization``.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}

        # NOTE(review): post_cleanup is not part of the uuid; confirm that
        # this is intentional.
        pass_uuids = [pass_.uuid() for pass_ in self.passes]
        pass_uuids.append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = pass_uuids
        return InductorPass.hash_dict(state)