pass_manager.py 4.93 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4

5
6
from torch import fx as fx

7
from vllm import envs
8
from vllm.config import VllmConfig, set_current_vllm_config
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
from vllm.utils.system_utils import set_env_var
12
13
14

from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
15

16
if current_platform.is_cuda_alike():
17
    from .activation_quant_fusion import ActivationQuantFusionPass
18
    from .fusion import RMSNormQuantFusionPass
19
    from .fusion_attn import AttnFusionPass
20
    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
21
    from .sequence_parallelism import SequenceParallelismPass
22

23
24
25
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

26
from .fix_functionalization import FixFunctionalizationPass
27
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
28
from .noop_elimination import NoOpEliminationPass
29
30
31
32

# Module-level logger; used below to report passes skipped for a given shape.
logger = init_logger(__name__)


33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def with_pattern_match_debug(fn):
    """
    Decorate ``fn`` so that Inductor pattern-match debugging is enabled
    only while ``fn`` runs.

    When envs.VLLM_PATTERN_MATCH_DEBUG is set (not None), its value is
    exported as TORCHINDUCTOR_PATTERN_MATCH_DEBUG for the duration of the
    call. Otherwise ``fn`` is called unchanged. Scoping the variable this
    way keeps Inductor's builtin pattern matching out of the debug log.
    """

    @functools.wraps(fn)
    def _debug_wrapper(*args, **kwargs):
        pattern_debug = envs.VLLM_PATTERN_MATCH_DEBUG
        if pattern_debug is None:
            # No debug requested: plain pass-through call.
            return fn(*args, **kwargs)
        # A per-rank filter could be applied here if desired.
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", pattern_debug):
            return fn(*args, **kwargs)

    return _debug_wrapper


51
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    configure() must be called before the manager is run or hashed: it is
    where pass_config, post_cleanup and fix_functionalization are created.

    The order of the post-grad post-passes is:
    1. self.passes, in the order they were appended. These are the default
       passes enabled by configure() (NoOpEliminationPass, fusion passes,
       ...) plus any passes registered with add(). Custom passes (e.g. a
       configured post_grad_custom_post_pass) presumably arrive via add();
       confirm at the call site. A pass whose is_applicable(runtime_shape)
       is False is skipped.
    2. post_cleanup (PostCleanupPass)
    3. fix_functionalization (FixFunctionalizationPass), always last
    This way, all passes before fix_functionalization operate on a
    functionalized graph.
    """

    def __init__(self):
        # Ordered passes run by __call__; populated by configure() and add().
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph):
        """
        Run the configured pass pipeline on ``graph``.

        Pass return values are ignored, so passes are expected to modify
        ``graph`` in place. VllmInductorPass.dump_prefix serves as a running
        index across the passes. It is reset to None on exit, including when
        a pass raises, so a failed compilation leaves no stale index behind.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index
        try:
            shape = get_pass_context().runtime_shape
            for pass_ in self.passes:
                if pass_.is_applicable(shape):
                    pass_(graph)
                    VllmInductorPass.dump_prefix += 1
                else:
                    logger.debug("Skipping %s with shape %s", pass_, shape)

            # post-cleanup goes before fix_functionalization
            # because it requires a functional graph
            self.post_cleanup(graph)
            VllmInductorPass.dump_prefix += 1

            # always run fix_functionalization last
            self.fix_functionalization(graph)
        finally:
            VllmInductorPass.dump_prefix = None  # Cleanup index

    def configure(self, config: VllmConfig):
        """
        Build the default pass pipeline from ``config``.

        Stores config.compilation_config.pass_config, appends one pass per
        enabled flag to self.passes, and creates the post_cleanup and
        fix_functionalization passes that __call__ always runs at the end.
        Calling this more than once appends the default passes again.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.enable_noop:
                self.passes += [NoOpEliminationPass(config)]

            # NOTE(review): AsyncTPPass and AllReduceFusionPass are only
            # imported on CUDA; enabling these flags on another platform
            # would raise NameError. Presumably config validation prevents
            # that; confirm.
            if self.pass_config.enable_sequence_parallelism:
                self.passes += [SequenceParallelismPass(config)]
                if self.pass_config.enable_async_tp:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.enable_fi_allreduce_fusion:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.enable_fusion:
                self.passes += [RMSNormQuantFusionPass(config)]
                self.passes += [ActivationQuantFusionPass(config)]

            if self.pass_config.enable_attn_fusion:
                self.passes += [AttnFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [QKNormRoPEFusionPass(config)]

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass):
        """Append ``pass_`` to the end of the current pass list."""
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        Every pass that __call__ runs contributes, in execution order:
        self.passes, then post_cleanup, then fix_functionalization.
        """
        state = {
            "pass_config": self.pass_config.uuid(),
            "passes": [pass_.uuid() for pass_ in self.passes],
        }
        # post_cleanup runs on every call, so a change to it must also change
        # the cache key; it was previously missing from the hashed state.
        state["passes"].append(self.post_cleanup.uuid())
        state["passes"].append(self.fix_functionalization.uuid())

        return InductorPass.hash_dict(state)