pass_manager.py 5.73 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4

5
6
from torch import fx as fx

7
from vllm import envs
8
from vllm._aiter_ops import rocm_aiter_ops
9
from vllm.config import VllmConfig, set_current_vllm_config
10
from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12
from vllm.utils.system_utils import set_env_var
13
14
15

from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
16

17
18
if rocm_aiter_ops.is_enabled():
    from vllm.compilation.rocm_aiter_fusion import (
19
        RocmAiterRMSNormFusionPass,
20
21
22
        RocmAiterSiluMulFp8GroupQuantFusionPass,
    )

23
if current_platform.is_cuda_alike():
24
    from .activation_quant_fusion import ActivationQuantFusionPass
25
    from .fusion import RMSNormQuantFusionPass
26
    from .fusion_attn import AttnFusionPass
27
    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
28
    from .sequence_parallelism import SequenceParallelismPass
29

30
31
32
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

33
from .fix_functionalization import FixFunctionalizationPass
34
35
36
37
38
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
39
from .noop_elimination import NoOpEliminationPass
40
41
42
43

# Module-level logger; PostGradPassManager.__call__ uses it to report
# passes that are skipped for the current compile range.
logger = init_logger(__name__)


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def with_pattern_match_debug(fn):
    """
    Decorate ``fn`` so that Inductor pattern-match debugging is switched on
    only while ``fn`` is executing.

    When ``envs.VLLM_PATTERN_MATCH_DEBUG`` is not ``None``, its value is
    applied to ``TORCHINDUCTOR_PATTERN_MATCH_DEBUG`` through ``set_env_var``
    around the call; otherwise ``fn`` is invoked unchanged.
    Scoping the setting to the call avoids logging builtin Inductor
    pattern matching.
    """

    @functools.wraps(fn)
    def scoped_call(*call_args, **call_kwargs):
        pattern_debug = envs.VLLM_PATTERN_MATCH_DEBUG
        if pattern_debug is None:
            # Debugging was not requested: plain pass-through.
            return fn(*call_args, **call_kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", pattern_debug):
            return fn(*call_args, **call_kwargs)

    return scoped_call


62
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    When the manager is called on a graph, the passes run in this order:
    1. every pass in ``self.passes`` (populated by ``configure`` from the
       pass config, plus anything appended through ``add``) that is
       applicable to the current compile range
    2. post_cleanup
    3. fix_functionalization
    fix_functionalization always runs last, so all other passes operate on
    a functionalized graph.
    """

    def __init__(self):
        # Ordered list of passes executed by __call__; filled in by
        # configure() and add().
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run all managed passes on ``graph``.

        ``VllmInductorPass.dump_prefix`` is used as a running index: it is
        reset to 0 at the start, incremented after every pass that actually
        ran, and set back to ``None`` when the manager is done. Passes that
        are not applicable to the current compile range are skipped (and
        logged at debug level) without advancing the index.

        ``configure`` must have been called first, since it creates the
        ``post_cleanup`` and ``fix_functionalization`` passes used here.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if not pass_.is_applicable_for_range(compile_range):
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )
                    continue
                pass_(graph)
                VllmInductorPass.dump_prefix += 1

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index. Done in ``finally`` so that a pass raising an
            # exception cannot leave a stale index on the shared class
            # attribute.
            VllmInductorPass.dump_prefix = None

    def configure(self, config: VllmConfig) -> None:
        """
        Build the default pass list from ``config``.

        Stores ``config.compilation_config.pass_config`` on the manager,
        appends one pass per enabled flag (in the fixed order below) to
        ``self.passes``, and creates the ``post_cleanup`` and
        ``fix_functionalization`` passes that ``__call__`` always runs.

        NOTE: several of the pass classes used here are imported
        conditionally at module level (platform / AITER checks). This method
        presumably relies on the config only enabling passes that are
        available on the current platform - verify against config
        validation.
        """
        self.pass_config = config.compilation_config.pass_config
        pass_config = self.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if pass_config.eliminate_noops:
                self.passes.append(NoOpEliminationPass(config))

            if pass_config.enable_sp:
                self.passes.append(SequenceParallelismPass(config))
                if pass_config.fuse_gemm_comms:
                    self.passes.append(AsyncTPPass(config))

            if pass_config.fuse_allreduce_rms:
                self.passes.append(AllReduceFusionPass(config))

            if pass_config.fuse_norm_quant:
                self.passes.append(RMSNormQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(RocmAiterRMSNormFusionPass(config))

            if pass_config.fuse_act_quant:
                self.passes.append(ActivationQuantFusionPass(config))
                if rocm_aiter_ops.is_enabled():
                    self.passes.append(
                        RocmAiterSiluMulFp8GroupQuantFusionPass(config)
                    )

            if pass_config.fuse_attn_quant:
                self.passes.append(AttnFusionPass(config))

            if pass_config.enable_qk_norm_rope_fusion:
                self.passes.append(QKNormRoPEFusionPass(config))

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """Append a custom pass; it runs after the passes already registered."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        state = {
            "pass_config": self.pass_config.compute_hash(),
            "passes": [pass_.uuid() for pass_ in self.passes],
        }
        # NOTE(review): post_cleanup is not part of the uuid although it
        # always runs - confirm this is intentional.
        state["passes"].append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)

        return InductorPass.hash_dict(state)