"vscode:/vscode.git/clone" did not exist on "038914b7c891c0b5b2853ec0574062dc3bea8073"
pass_manager.py 5.22 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4

5
6
from torch import fx as fx

7
from vllm import envs
8
from vllm.config import VllmConfig, set_current_vllm_config
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
from vllm.utils.system_utils import set_env_var
12
13
14

from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
15

16
if current_platform.is_cuda_alike():
17
    from .activation_quant_fusion import ActivationQuantFusionPass
18
    from .fusion import RMSNormQuantFusionPass
19
    from .fusion_attn import AttnFusionPass
20
    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
21
    from .sequence_parallelism import SequenceParallelismPass
22

23
24
25
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

26
from .fix_functionalization import FixFunctionalizationPass
27
28
29
30
31
from .inductor_pass import (
    CustomGraphPass,
    InductorPass,
    get_pass_context,
)
32
from .noop_elimination import NoOpEliminationPass
33
34
35
36

logger = init_logger(__name__)


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def with_pattern_match_debug(fn):
    """
    Decorator that enables Inductor's pattern-match debug output while the
    wrapped callable runs.

    When ``envs.VLLM_PATTERN_MATCH_DEBUG`` is not None, its value is exported
    as TORCHINDUCTOR_PATTERN_MATCH_DEBUG only for the duration of the call,
    so Inductor's builtin pattern matching outside the call is not logged.
    When it is None, the callable is invoked unchanged.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        debug_setting = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_setting is None:
            return fn(*args, **kwargs)
        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_setting):
            return fn(*args, **kwargs)

    return wrapper


55
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. the passes in ``self.passes``, in insertion order: the default passes
       enabled by the pass config in ``configure`` (NoOpEliminationPass,
       sequence-parallelism and fusion passes) and any passes appended with
       ``add`` (e.g. config["post_grad_custom_post_pass"], if it exists).
       A pass whose ``is_applicable_for_range`` rejects the current compile
       range is skipped.
    2. post_cleanup (PostCleanupPass), which needs a functional graph
    3. fix_functionalization, always last
    This way, all passes operate on a functionalized graph.
    """

    def __init__(self):
        # Ordered passes run by __call__; filled in by configure() and add().
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph):
        """
        Run the configured passes on ``graph`` in place.

        ``configure`` must have been called first: it creates the
        ``post_cleanup`` and ``fix_functionalization`` passes run here.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            compile_range = get_pass_context().compile_range
            for pass_ in self.passes:
                if pass_.is_applicable_for_range(compile_range):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug(
                        "Skipping %s with compile range %s", pass_, compile_range
                    )

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            # Cleanup index even when a pass raises, so a failed run does
            # not leave a stale dump_prefix on VllmInductorPass.
            VllmInductorPass.dump_prefix = None

    def configure(self, config: VllmConfig):
        """
        Build the default pass pipeline from
        ``config.compilation_config.pass_config``.

        Appends each enabled default pass to ``self.passes`` and creates the
        ``post_cleanup`` and ``fix_functionalization`` passes that ``__call__``
        always runs at the end.
        """
        self.pass_config = config.compilation_config.pass_config

        # NOTE(review): the fusion / sequence-parallelism pass classes are only
        # imported on CUDA-alike platforms (AsyncTPPass and AllReduceFusionPass
        # only on CUDA). Presumably the pass config is validated elsewhere so
        # these flags are never enabled on other platforms; otherwise the
        # corresponding branch raises NameError. Confirm.

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.eliminate_noops:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sp:
                self.passes += [SequenceParallelismPass(config)]
                if self.pass_config.fuse_gemm_comms:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.fuse_allreduce_rms:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.fuse_norm_quant:
                self.passes += [RMSNormQuantFusionPass(config)]
            if self.pass_config.fuse_act_quant:
                self.passes += [ActivationQuantFusionPass(config)]

            if self.pass_config.fuse_attn_quant:
                self.passes += [AttnFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [QKNormRoPEFusionPass(config)]

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass):
        """Append ``pass_`` so it runs after the passes already registered."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.
        """
        state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
        for pass_ in self.passes:
            state["passes"].append(pass_.uuid())
        # post_cleanup and fix_functionalization always transform the graph
        # in __call__, so both belong in the cache key (in execution order).
        state["passes"].append(self.post_cleanup.uuid())
        state["passes"].append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)

        return InductorPass.hash_dict(state)