# SPDX-License-Identifier: Apache-2.0

import operator
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

from vllm.logger import init_logger

from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    def __call__(self, graph: torch.fx.Graph):
        """
        Run the pass over `graph`, mutating it in place.

        Every `auto_functionalized` node whose wrapped op (`node.args[0]`)
        is one of the ops handled below is replaced by a direct call to that
        op. The replaced nodes and their getitem users are only staged for
        removal while iterating and are erased in one sweep at the end, so
        the graph is never shrunk while `graph.nodes` is being walked.

        :param graph: The FX graph to de-functionalize.
        """
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")

        # Reset per call so nodes staged by a previous run are not erased
        # again.
        self.nodes_to_remove: List[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs['query']
                mm_node = query.args[0].args[0]

                # rotary_embedding is a special case: the two mutating inputs
                # are query and key, which are slices of mm_node.
                # While functionalized, results at[1] and at[2] are scattered
                # back into mm_node. After de-functionalization, we can just
                # use mm_node directly.
                for getitem_node in self.getitem_users(node).values():
                    for user_of_getitem in getitem_node.users:
                        if is_func(user_of_getitem,
                                   torch.ops.aten.slice_scatter.default):
                            user_of_getitem.replace_all_uses_with(mm_node)
                            self._remove(user_of_getitem)
                    self._remove(getitem_node)

                self.insert_defunctionalized(graph, node)
                self._remove(node)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: 'input', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in (
                    torch.ops._C.rms_norm.default,
                    torch.ops._C.rms_norm_static_fp8_quant.default,
            ):
                mutated_args = {1: 'result'}
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('result', 'input'))
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: 'result'}
                self.defunctionalize(graph,
                                     node,
                                     mutated_args,
                                     args=('result', 'input', 'scale'))
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_fix_functionalization_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                     count_removed)
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()

    def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                           Iterable[torch.fx.Node]]):
        """
        Stage a node (or nodes) for removal at the end of the pass.

        Nothing is erased here; the node(s) are only appended to
        `self.nodes_to_remove`, which `__call__` erases in order.

        :param node_or_nodes: A single node, or an iterable of nodes.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(self,
                        graph: torch.fx.Graph,
                        node: torch.fx.Node,
                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
                        args: Optional[Tuple[Union[torch.fx.Node, str],
                                             ...]] = None):
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.

        :param graph: Graph containing `node`
        :param node: The auto-functionalized node to replace
        :param mutated_args: The mutated arguments, indexed by getitem index.
        String values are looked up in `node.kwargs`.
        :param args: Optional positional args for the new call; forwarded to
        insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(self, node: torch.fx.Node,
                                        mutated_args: Dict[int,
                                                           Union[torch.fx.Node,
                                                                 str]]):
        """
        Replace all getitem users of the auto-functionalized node with the
        mutated arguments, and stage those getitem nodes for removal.

        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        :raises KeyError: If a getitem user reads an index that is missing
        from `mutated_args`.
        """
        for idx, user in self.getitem_users(node).items():
            arg = mutated_args[idx]
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.

        If several getitem users read the same index, only the last one
        encountered is kept in the returned dict.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                users[idx] = user
        return users

    def insert_defunctionalized(self,
                                graph: torch.fx.Graph,
                                node: torch.fx.Node,
                                args: Optional[Tuple[Union[torch.fx.Node, str],
                                                     ...]] = None):
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), \
            f"node must be auto-functionalized, is {node} instead"

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                resolved_args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg
                    for arg in args)
                graph.call_function(function, args=resolved_args)