pass_manager.py 3.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from torch import fx as fx

6
from vllm.config import VllmConfig
7
from vllm.logger import init_logger
8
9
from vllm.platforms import current_platform

10
if current_platform.is_cuda_alike():
11
    from .activation_quant_fusion import ActivationQuantFusionPass
12
13
    from .fusion import FusionPass
    from .fusion_attn import AttnFusionPass
14

15
16
17
if current_platform.is_cuda():
    from .collective_fusion import AllReduceFusionPass, AsyncTPPass

18
from .fix_functionalization import FixFunctionalizationPass
19
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
20
from .noop_elimination import NoOpEliminationPass
21
22
from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
23
24
25
26

# Module-level logger, created via vLLM's init_logger with this module's name.
logger = init_logger(__name__)


27
class PostGradPassManager(CustomGraphPass):
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.

    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    Passes run in the order in which they were appended to ``self.passes``,
    followed by fix_functionalization:
    1. default passes appended by configure(), each only if enabled in the
       pass config (NoOpEliminationPass, SequenceParallelismPass,
       AsyncTPPass, FusionPass, ActivationQuantFusionPass, AttnFusionPass,
       AllReduceFusionPass)
    2. custom passes appended through add()
       (e.g. config["post_grad_custom_post_pass"], if it exists)
    3. fix_functionalization
    This way, all passes operate on a functionalized graph.

    configure() must be called before the manager is invoked or uuid() is
    requested: it is the only place where ``pass_config`` and
    ``fix_functionalization`` are created.
    """

    def __init__(self):
        # Ordered list of passes to run; fix_functionalization is kept
        # separately so that it always runs last.
        self.passes: list[InductorPass] = []

    def __call__(self, graph: fx.Graph) -> None:
        """
        Run every registered pass that applies to the current runtime shape
        on ``graph``, then run fix_functionalization.

        :param graph: the fx graph handed to each pass.
        """
        shape = get_pass_context().runtime_shape
        for pass_ in self.passes:
            if pass_.is_applicable_for_shape(shape):
                pass_(graph)

        # always run fix_functionalization last
        self.fix_functionalization(graph)

    def configure(self, config: VllmConfig) -> None:
        """
        Store the pass config and append the default passes it enables.

        Several pass classes used here are imported conditionally at module
        level (on ``current_platform.is_cuda_alike()`` or
        ``current_platform.is_cuda()``), so enabling the matching option on
        another platform fails with a NameError.

        :param config: the vLLM config; ``compilation_config.pass_config``
            selects which passes are added, and ``config`` itself is passed
            to every pass that is constructed.
        """
        self.pass_config = config.compilation_config.pass_config
        pass_config = self.pass_config

        if pass_config.enable_noop:
            self.passes.append(NoOpEliminationPass(config))

        if pass_config.enable_sequence_parallelism:
            self.passes.append(SequenceParallelismPass(config))
            # Async TP is only considered when sequence parallelism is on.
            # AsyncTPPass is imported only when current_platform.is_cuda().
            if pass_config.enable_async_tp:
                self.passes.append(AsyncTPPass(config))

        if pass_config.enable_fusion:
            # Both classes are imported only on CUDA-alike platforms.
            self.passes.append(FusionPass.instance(config))
            self.passes.append(ActivationQuantFusionPass(config))

        if pass_config.enable_attn_fusion:
            # AttnFusionPass is imported only on CUDA-alike platforms.
            self.passes.append(AttnFusionPass(config))

        if pass_config.enable_fi_allreduce_fusion:
            # AllReduceFusionPass is imported only when
            # current_platform.is_cuda().
            self.passes.append(AllReduceFusionPass(config))

        # Not stored in self.passes: __call__ runs it unconditionally after
        # all other passes.
        self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """
        Append a custom pass; it runs after all previously appended passes.

        :param pass_: the pass to register; must be an InductorPass instance.
        :raises AssertionError: if ``pass_`` is not an InductorPass.
        """
        assert isinstance(pass_, InductorPass), (
            f"expected an InductorPass instance, got {type(pass_).__name__}")
        self.passes.append(pass_)

    def uuid(self):
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        :return: the result of InductorPass.hash_dict over the pass config
            uuid and the uuids of all passes (fix_functionalization last).
        """
        state = {
            "pass_config": self.pass_config.uuid(),
            "passes": [pass_.uuid() for pass_ in self.passes],
        }
        # fix_functionalization always runs, so it is part of the key too.
        state["passes"].append(self.fix_functionalization.uuid())
        return InductorPass.hash_dict(state)