pass_manager.py 6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
6

7
8
from torch import fx as fx

9
from vllm import envs
10
from vllm._aiter_ops import rocm_aiter_ops
11
from vllm.config import VllmConfig, set_current_vllm_config
12
from vllm.logger import init_logger
13
from vllm.platforms import current_platform
14
from vllm.utils.system_utils import set_env_var
15
16
17

from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
18

19
20
if rocm_aiter_ops.is_enabled():
    from vllm.compilation.rocm_aiter_fusion import (
21
        RocmAiterRMSNormFusionPass,
22
23
24
        RocmAiterSiluMulFp8GroupQuantFusionPass,
    )

25
if current_platform.is_cuda_alike():
26
    from .activation_quant_fusion import ActivationQuantFusionPass
27
    from .fusion import RMSNormQuantFusionPass
28
    from .fusion_attn import AttnFusionPass
29
    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
30
    from .sequence_parallelism import SequenceParallelismPass
31

32
33
34
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

35
from .fix_functionalization import FixFunctionalizationPass
36
37
38
39
40
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
41
from .noop_elimination import NoOpEliminationPass
42
43
44

# Module-level logger; used below to report passes that are skipped for the
# current compile range.
logger = init_logger(__name__)

45
46
# Type variables used by `with_pattern_match_debug` so that the decorated
# callable keeps its parameter list (P) and return type (R) for type checkers.
P = ParamSpec("P")
R = TypeVar("R")
47

48
49

def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate ``fn`` so Inductor pattern match debugging is scoped to the call.

    When ``envs.VLLM_PATTERN_MATCH_DEBUG`` is not ``None``, its value is
    applied to the ``TORCHINDUCTOR_PATTERN_MATCH_DEBUG`` environment variable
    through ``set_env_var`` while ``fn`` runs. Otherwise ``fn`` is called
    without touching the environment.
    Used to avoid logging builtin Inductor pattern matching.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature and return value as ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        pattern_debug = envs.VLLM_PATTERN_MATCH_DEBUG
        if pattern_debug is None:
            # Debugging not requested: plain pass-through.
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", pattern_debug):
            return fn(*args, **kwargs)

    return wrapper


67
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    When the manager is called on a graph, it runs:
    1. every pass in ``self.passes`` in insertion order, i.e. the passes
       enabled by ``configure`` and the passes registered through ``add``
       (passes that are not applicable to the current compile range are
       skipped)
    2. post_cleanup
    3. fix_functionalization
    This way, all passes operate on a functionalized graph.

    ``configure`` must be called before the manager is invoked or ``uuid``
    is requested, because it creates ``pass_config``, ``post_cleanup`` and
    ``fix_functionalization``.
    """

    def __init__(self) -> None:
        """Create a manager with an empty pass list."""
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run all applicable passes on ``graph``, then the fixed trailing passes.

        ``VllmInductorPass.dump_prefix`` is used as a running index: it is
        reset to 0 on entry, incremented after every pass that actually runs,
        and set back to ``None`` on exit.

        Args:
            graph: The fx graph handed to each pass.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if pass_.is_applicable_for_range(compile_range):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index. Done in `finally` so that a failing pass does not
            # leave a stale value in this class-level attribute.
            VllmInductorPass.dump_prefix = None

    def configure(self, config: VllmConfig) -> None:
        """
        Build the pass list from ``config.compilation_config.pass_config``.

        Each enabled flag appends the matching pass(es) to ``self.passes``;
        the trailing ``post_cleanup`` and ``fix_functionalization`` passes are
        always created.

        NOTE: several pass classes are imported conditionally at module import
        time (depending on the platform / ``rocm_aiter_ops.is_enabled()``), so
        enabling a flag whose class was not imported raises ``NameError`` here.

        Args:
            config: The vLLM config used to construct every pass.
        """
        self.pass_config = config.compilation_config.pass_config
        pass_config = self.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if pass_config.eliminate_noops:
                self.passes.append(NoOpEliminationPass(config))

            if pass_config.enable_sp:
                self.passes.append(SequenceParallelismPass(config))
                if pass_config.fuse_gemm_comms:
                    self.passes.append(AsyncTPPass(config))

            if pass_config.fuse_allreduce_rms:
                self.passes.append(AllReduceFusionPass(config))

            if pass_config.fuse_norm_quant:
                self.passes.append(RMSNormQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(RocmAiterRMSNormFusionPass(config))

            if pass_config.fuse_act_quant:
                self.passes.append(ActivationQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(
                        RocmAiterSiluMulFp8GroupQuantFusionPass(config)
                    )

            if pass_config.fuse_attn_quant:
                self.passes.append(AttnFusionPass(config))

            if pass_config.enable_qk_norm_rope_fusion:
                self.passes.append(QKNormRoPEFusionPass(config))

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """
        Append a custom pass to the end of the pass list.

        Args:
            pass_: The pass to register; must be an ``InductorPass`` instance.
        """
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the pass config hash,
        the UUIDs of the passes in ``self.passes`` and of
        ``fix_functionalization``, and the current compile range.
        See InductorPass for more info.

        Returns:
            The hash of that state as produced by ``InductorPass.hash_dict``.
        """
        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}

        pass_uuids = [pass_.uuid() for pass_ in self.passes]
        pass_uuids.append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = pass_uuids
        return InductorPass.hash_dict(state)