pass_manager.py 6.46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
6

7
8
from torch import fx as fx

9
from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
12
from vllm.config import VllmConfig, set_current_vllm_config
13
from vllm.logger import init_logger
14
from vllm.platforms import current_platform
15
from vllm.utils.system_utils import set_env_var
16
17

from .vllm_inductor_pass import VllmInductorPass
18

19
if rocm_aiter_ops.is_enabled():
20
    from .fusion.rocm_aiter_fusion import (
21
        RocmAiterRMSNormQuantFusionPass,
22
        RocmAiterSiluMulFp8GroupQuantFusionPass,
23
        RocmAiterTritonAddRMSNormPadFusionPass,
24
25
    )

26
if current_platform.is_cuda_alike():
27
28
29
30
31
    from .fusion.act_quant_fusion import ActivationQuantFusionPass
    from .fusion.attn_quant_fusion import AttnFusionPass
    from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
    from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
    from .fusion.sequence_parallelism import SequenceParallelismPass
32
    from .utility.split_coalescing import SplitCoalescingPass
33

34
if current_platform.is_cuda():
35
36
    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
    from .fusion.collective_fusion import AsyncTPPass
37

38
39
40
41
42
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
43
44
from .utility.fix_functionalization import FixFunctionalizationPass
from .utility.noop_elimination import NoOpEliminationPass
45
46
47

# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)

48
49
# Typing placeholders for the parameters (P) and return type (R) of a wrapped
# callable; used by with_pattern_match_debug so the decorated function keeps
# its original signature for type checkers.
P = ParamSpec("P")
R = TypeVar("R")
50

51
52

def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorator that enables Inductor pattern-match debugging only while
    the wrapped function is running.

    If envs.VLLM_PATTERN_MATCH_DEBUG is not None, its value is exposed as
    TORCHINDUCTOR_PATTERN_MATCH_DEBUG (through set_env_var) around the call;
    otherwise the function is invoked unchanged.
    Scoping the variable to the call avoids logging builtin Inductor
    pattern matching.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        pattern = envs.VLLM_PATTERN_MATCH_DEBUG
        if pattern is None:
            # Debugging not requested: plain pass-through.
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", pattern):
            return fn(*args, **kwargs)

    return wrapper


70
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    When the manager is called on a graph, passes run in this order:
    1. every pass in self.passes, in insertion order. configure() appends
       the default passes enabled by the pass config (noop elimination and
       the fusion passes); add() appends custom passes (presumably this is
       how config["post_grad_custom_post_pass"] gets registered - the caller
       is not visible in this file). Passes that are not applicable to the
       current compile range are skipped.
    2. post_cleanup
    3. fix_functionalization
    This way, every pass before fix_functionalization operates on a
    functionalized graph.
    """

    def __init__(self) -> None:
        """Create an empty manager; configure() must be called before use."""
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run the registered passes, then post_cleanup and
        fix_functionalization, on `graph`.

        VllmInductorPass.dump_prefix is used as a running index while the
        passes execute: it starts at 0, is incremented after each pass that
        ran, and is reset to None when the run finishes.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if not pass_.is_applicable_for_range(compile_range):
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )
                    continue
                pass_(graph)
                VllmInductorPass.dump_prefix += 1

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index. Done in `finally` so that a pass raising midway
            # does not leave a stale index on the shared class attribute.
            VllmInductorPass.dump_prefix = None

    def configure(self, config: VllmConfig) -> None:
        """
        Build the default pass pipeline from `config`.

        Stores config.compilation_config.pass_config on the manager, appends
        the passes enabled by it to self.passes (in the order below), and
        creates the post_cleanup and fix_functionalization passes.

        Note that several pass classes are imported conditionally at module
        level (depending on the platform / ROCm AITER availability), so the
        corresponding config flags are only usable where those imports ran.
        """
        self.pass_config = config.compilation_config.pass_config
        pass_config = self.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if pass_config.eliminate_noops:
                self.passes.append(NoOpEliminationPass(config))

            if pass_config.enable_sp:
                self.passes.append(SequenceParallelismPass(config))
                if pass_config.fuse_gemm_comms:
                    self.passes.append(AsyncTPPass(config))

            if pass_config.fuse_allreduce_rms:
                self.passes.append(AllReduceFusionPass(config))

            if pass_config.fuse_norm_quant:
                self.passes.append(RMSNormQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(RocmAiterRMSNormQuantFusionPass(config))

            if pass_config.fuse_act_quant:
                self.passes.append(ActivationQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(
                        RocmAiterSiluMulFp8GroupQuantFusionPass(config)
                    )

            if pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes.append(RocmAiterTritonAddRMSNormPadFusionPass(config))

            if pass_config.fuse_attn_quant:
                self.passes.append(AttnFusionPass(config))

            if pass_config.enable_qk_norm_rope_fusion:
                # Split coalescing is registered right before the QK norm +
                # RoPE fusion pass.
                self.passes.append(SplitCoalescingPass(config))
                self.passes.append(QKNormRoPEFusionPass(config))

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append a custom pass; it runs after all previously added passes."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        The hashed state contains the pass config hash, the current compile
        range, and the uuids of self.passes followed by the uuid of
        fix_functionalization.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}

        passes = [pass_.uuid() for pass_ in self.passes]
        passes.append(self.fix_functionalization.uuid())
        # NOTE(review): post_cleanup is not included in the uuid - confirm
        # this is intentional.

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = passes
        return InductorPass.hash_dict(state)