pass_manager.py 6.21 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
6

7
8
from torch import fx as fx

9
from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.config import VllmConfig, set_current_vllm_config
12
from vllm.logger import init_logger
13
from vllm.platforms import current_platform
14
from vllm.utils.system_utils import set_env_var
15
16
17

from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
18

19
20
if rocm_aiter_ops.is_enabled():
    from vllm.compilation.rocm_aiter_fusion import (
21
        RocmAiterRMSNormQuantFusionPass,
22
        RocmAiterSiluMulFp8GroupQuantFusionPass,
23
        RocmAiterTritonAddRMSNormPadFusionPass,
24
25
    )

26
if current_platform.is_cuda_alike():
27
    from .activation_quant_fusion import ActivationQuantFusionPass
28
    from .fusion import RMSNormQuantFusionPass
29
    from .fusion_attn import AttnFusionPass
30
    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
31
    from .sequence_parallelism import SequenceParallelismPass
32

33
34
35
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

36
from .fix_functionalization import FixFunctionalizationPass
37
38
39
40
41
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
42
from .noop_elimination import NoOpEliminationPass
43
44
45

logger = init_logger(__name__)

46
47
P = ParamSpec("P")
R = TypeVar("R")
48

49
50

def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate ``fn`` so Inductor pattern-match debugging is active only while
    ``fn`` runs.

    If ``envs.VLLM_PATTERN_MATCH_DEBUG`` is not None, its value is applied to
    ``TORCHINDUCTOR_PATTERN_MATCH_DEBUG`` via ``set_env_var`` for the duration
    of the call. Otherwise ``fn`` is called directly. Scoping the setting to
    this call avoids logging builtin Inductor pattern matching.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            # Debugging not requested: run without touching the environment.
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper


68
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. ``self.passes``, in the order they were appended. These are the
       default passes enabled by ``configure()`` from the pass config (e.g.
       NoOpEliminationPass and the fusion passes) plus any passes registered
       through ``add()`` (presumably including
       config["post_grad_custom_post_pass"], if it exists).
       A pass is skipped when it is not applicable for the current
       compile range.
    2. post_cleanup (PostCleanupPass)
    3. fix_functionalization (FixFunctionalizationPass), always last
    This way, all passes before fix_functionalization operate on a
    functionalized graph.

    ``configure()`` must be called before ``__call__()`` or ``uuid()``,
    because it sets ``pass_config``, ``post_cleanup`` and
    ``fix_functionalization``.
    """

    def __init__(self) -> None:
        # Ordered passes run by __call__; filled by configure() and add().
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run every configured post-grad pass on ``graph``.

        ``VllmInductorPass.dump_prefix`` is reset to 0 up front and increments
        after each pass that runs. It is reset to None when done, even if a
        pass raises, so a stale index never leaks into a later compilation.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if pass_.is_applicable_for_range(compile_range):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            VllmInductorPass.dump_prefix = None  # Cleanup index

    def configure(self, config: VllmConfig) -> None:
        """
        Build the pass pipeline from ``config.compilation_config.pass_config``.

        Enabled passes are appended to ``self.passes``; passes that were
        already present are kept. Also creates the ``post_cleanup`` and
        ``fix_functionalization`` passes that ``__call__`` always runs.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.eliminate_noops:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sp:
                self.passes += [SequenceParallelismPass(config)]
                if self.pass_config.fuse_gemm_comms:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.fuse_allreduce_rms:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.fuse_norm_quant:
                self.passes += [RMSNormQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [RocmAiterRMSNormQuantFusionPass(config)]

            if self.pass_config.fuse_act_quant:
                self.passes += [ActivationQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

            if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]

            if self.pass_config.fuse_attn_quant:
                self.passes += [AttnFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [QKNormRoPEFusionPass(config)]

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append ``pass_`` to the end of ``self.passes``."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        Every pass that ``__call__`` may run contributes to the key. That
        includes post_cleanup and fix_functionalization, so a change to any
        of them invalidates previously cached graphs.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}

        passes = [pass_.uuid() for pass_ in self.passes]
        # post_cleanup transforms every graph too, so it must be part of the
        # cache key. Order mirrors execution order in __call__.
        passes.append(self.post_cleanup.uuid())
        passes.append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = passes
        return InductorPass.hash_dict(state)