fix_functionalization.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import operator
5
from collections.abc import Iterable
6
7
8
9
10

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

from vllm.logger import init_logger
11
from vllm.platforms import current_platform
12

13
14
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
15
16
17
18
19
20
21
22
23
24
25
26
27

logger = init_logger(__name__)


class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

28
    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """
        De-functionalize the supported auto_functionalized nodes of `graph`.

        Every node that is a call to `auto_functionalized` and whose wrapped
        op (`node.args[0]`) is handled by the if-elif chain below is replaced
        by a direct call to that op. The auto_functionalized node and its
        obsolete getitem users are only staged via `_remove` while iterating;
        they are erased from the graph in one sweep at the end.

        :param graph: The fx graph to rewrite in place.
        """
        # XPU does not support auto-functionalization yet.
        # Will enable this when switch to vllm-xpu-kernels.
        if current_platform.is_xpu():
            logger.debug(
                "XPU platform does not support fix functionalizationpass currently."
            )
            return

        # Start every invocation with an empty staging list.
        self.nodes_to_remove: list[torch.fx.Node] = []
        # Number of auto_functionalized nodes handled (for the debug log).
        count = 0

        rope_targets = [torch.ops._C.rotary_embedding.default]

        # Only registered in some builds, hence the hasattr guard.
        if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
            rope_targets.append(
                torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
            )

        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            # The first positional arg of auto_functionalized is the wrapped op.
            at_target = node.args[0]

            if at_target in rope_targets:
                query = kwargs["query"]
                key = kwargs["key"]
                getitem_nodes = self.getitem_users(node)

                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of an mm_node.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into mm_node. So after de-functionalization, we can
                    # just use mm_node directly.

                    # Input of the split_with_sizes that produced query and key.
                    mm_node = query.args[0].args[0]
                    for user in getitem_nodes.values():
                        for user_of_getitem in user.users:
                            # Already guaranteed by the all(...) check above;
                            # kept as a guard.
                            if is_func(
                                user_of_getitem, torch.ops.aten.slice_scatter.default
                            ):
                                user_of_getitem.replace_all_uses_with(mm_node)
                                self._remove(user_of_getitem)
                        self._remove(user)

                    self.insert_defunctionalized(graph, node)
                    self._remove(node)

                else:
                    # Directly replace the auto_functionalize(rotary_embedding)
                    # with the inplace rotary_embedding. In theory, we shouldn't
                    # do this blindly, but in practice in vLLM it's ok. The best
                    # solution is to use auto_functionalization_v2 and then use
                    # inductor's builtin defunctionalization (reinplacing) pass.
                    mutated_args = {1: "query", 2: "key"}
                    self.defunctionalize(graph, node, mutated_args)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: "input", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "scale", 3: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                torch.ops._C.rms_norm.default,
                torch.ops._C.rms_norm_static_fp8_quant.default,
            ]:
                mutated_args = {1: "result"}
                self.defunctionalize(graph, node, mutated_args)
            elif (
                hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
                and at_target
                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
            ):
                mutated_args = {
                    1: "allreduce_in",
                    2: "residual",
                    3: "norm_out",
                    4: "quant_out",
                    5: "scale_out",
                }
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input")
                )
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input", "scale")
                )
            elif (
                hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
                and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
            ):
                mutated_args = {1: "result", 2: "result_block_scale"}
                self.defunctionalize(
                    graph,
                    node,
                    mutated_args,
                    args=(
                        "result",
                        "result_block_scale",
                        "input",
                        "input_global_scale",
                    ),
                )
            # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
            elif at_target == torch.ops._C.fused_qk_norm_rope.default:
                mutated_args = {1: "qkv"}
                args = (
                    "qkv",
                    "num_heads_q",
                    "num_heads_k",
                    "num_heads_v",
                    "head_dim",
                    "eps",
                    "q_weight",
                    "k_weight",
                    "cos_sin_cache",
                    "is_neox",
                    "position_ids",
                    "forced_token_heads_per_warp",
                )
                self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
            elif (
                hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
                and at_target
                == torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
            ):
                mutated_args = {
                    1: "query",
                    2: "key",
                }
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            # only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
            elif (
                hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
                and at_target
                == torch.ops.vllm.function_with_mutated_args_and_return.default
            ):
                mutated_args = {1: "x"}
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.nodes_to_remove.clear()
208

209
    def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
        """
        Stage a node (or nodes) for removal at the end of the pass.

        Nothing is erased here: the node(s) are only added to
        `self.nodes_to_remove`, which `__call__` erases from the graph
        once the traversal is finished.

        :param node_or_nodes: A single fx node, or an iterable of fx nodes.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

218
219
220
221
    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, torch.fx.Node | str],
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.

        :param graph: Graph that contains `node`.
        :param node: The auto-functionalized node to replace.
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        :param args: Optional positional args for the new call, forwarded
        to insert_defunctionalized; if None, `node.kwargs` is used there.
        """
        # Redirect the getitem users first, then insert the direct call, and
        # finally stage the auto-functionalized node itself for removal.
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

234
    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
    ) -> None:
        """
        Replace mutated getitem users of the auto-functionalized node with the
        mutated arguments.

        Each replaced getitem user is staged for removal via `_remove`.

        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        :raises KeyError: If a getitem user with a non-zero index has no
        entry in `mutated_args`.
        """
        for idx, user in self.getitem_users(node).items():
            # Some functionalized nodes may return both a result at getitem[0]
            # as well as mutated args at getitem[1:...]
            if idx == 0:
                assert idx not in mutated_args, (
                    f"result at getitem[0] should not be in mutated_args for {node}"
                )
                # getitem[0] is left untouched here; insert_defunctionalized
                # rewires it to the new call.
                continue
            arg = mutated_args[idx]
            # A string entry names a kwarg of the auto-functionalized node.
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

257
    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.

        Users that are not operator.getitem calls are ignored. If several
        getitem users share an index, the one that appears last in
        `node.users` is the one kept.

        :param node: The node whose users are inspected.
        :return: Mapping from getitem index (`user.args[1]`) to the user node.
        """
        return {
            user.args[1]: user
            for user in node.users
            if is_func(user, operator.getitem)
        }

269
270
271
272
    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        A getitem[0] user of `node`, if present, is redirected to the new
        node and staged for removal. `node` itself is not staged here.

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), (
            f"node must be auto-functionalized, is {node} instead"
        )

        # Create a new call to the original function
        with graph.inserting_before(node):
            # The wrapped op is the first positional arg of auto_functionalized.
            function = node.args[0]
            if args is None:
                fn_node = graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                resolved_args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                fn_node = graph.call_function(function, args=resolved_args)

        # If the function returns a value as well as mutating args inplace,
        # the functionalized node will have a getitem[0] user that holds this value
        # Replace getitem[0] user of the auto-functionalized node
        # with the new defunctionalized node directly if it exists
        users = self.getitem_users(node)
        if 0 in users:
            user = users[0]
            user.replace_all_uses_with(fn_node)
            self._remove(user)