# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Union
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

import torch.fx
from torch import SymInt

from vllm.logger import init_logger

from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
      1. A chain of reshapes is equivalent to the last reshape called on the
      base tensor (input of the first reshape).
      2. A reshape that produces the shape of the input is redundant
      3. A slice that produces the shape of the input is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, 0, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]

    TODO(luka): This is currently tested in test_fusion,
     but separate tests could be good.
    """

    def __call__(self, graph: torch.fx.Graph):
        """
        Run the pass over ``graph`` in place.

        Nodes are visited in graph order; each reshape, slice and
        slice_scatter node is handed to the matching helper, which rewrites
        or erases it when it is redundant.
        """
        self.begin()
        self.dump_graph(graph, "before_noop_elimination")
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                count += self._simplify_reshape(graph, node)
            elif is_func(node, torch.ops.aten.slice.Tensor):
                count += self._remove_noop_slice(graph, node)
            elif is_func(node, torch.ops.aten.slice_scatter.default):
                count += self._remove_noop_slice_scatter(graph, node)

        logger.debug("Removed %s no-op reshapes and slices", count)
        self.dump_graph(graph, "after_noop_elimination")
        self.end_and_log()

    def _simplify_reshape(self, graph: torch.fx.Graph,
                          node: torch.fx.Node) -> int:
        """Collapse a reshape chain and drop the reshape if it is a no-op.

        :return: The number of nodes erased from the graph.
        """
        removed = 0

        # Case 1: rewrite reshape chains to reshapes on the base tensor
        producer = node.args[0]
        if is_func(producer, torch.ops.aten.reshape.default):
            # The new input is guaranteed not to be a reshape,
            # because we process nodes in order
            node.update_arg(0, producer.args[0])
            if not producer.users:
                graph.erase_node(producer)
                removed += 1

        # Case 2: remove this reshape if it produces the original shape
        source, shape = node.args[:2]
        source_shape = source.meta["val"].shape
        if len(shape) != len(source_shape):
            # Reshape changing rank, keep it
            return removed

        if shape.count(-1) > 1:
            # Invalid reshape args, keep it
            return removed

        if self.all_dims_equivalent(shape, source_shape):
            node.replace_all_uses_with(source)
            graph.erase_node(node)
            removed += 1

        return removed

    def _remove_noop_slice(self, graph: torch.fx.Graph,
                           node: torch.fx.Node) -> int:
        """Erase a slice node that selects its whole input.

        :return: The number of nodes erased from the graph (0 or 1).
        """
        if node.kwargs:
            # Only the positional form is analyzed; stay conservative.
            return 0

        source = node.args[0]
        dim_index, start, end, step = self._slice_params(node.args[1:5])
        dim_size = source.meta["val"].shape[dim_index]

        if self._covers_whole_dim(start, end, step, dim_size):
            node.replace_all_uses_with(source)
            graph.erase_node(node)
            return 1
        return 0

    def _remove_noop_slice_scatter(self, graph: torch.fx.Graph,
                                   node: torch.fx.Node) -> int:
        """Erase a slice_scatter node that overwrites its whole base tensor.

        :return: The number of nodes erased from the graph (0 or 1).
        """
        if node.kwargs:
            # Only the positional form is analyzed; stay conservative.
            return 0

        base, view = node.args[:2]
        dim_index, start, end, step = self._slice_params(node.args[2:6])
        base_shape = base.meta["val"].shape
        view_shape = view.meta["val"].shape
        view_dim = view_shape[dim_index]

        # Check that view fully covers base and the full view is used
        # (if the view fully covered the base after slicing but was not
        # fully used, we could replace slice_scatter with a simple slice
        # but that's a niche case).
        if (base_shape == view_shape
                and self._covers_whole_dim(start, end, step, view_dim)):
            node.replace_all_uses_with(view)
            graph.erase_node(node)
            return 1
        return 0

    @staticmethod
    def _slice_params(args: tuple) -> tuple:
        """Pad positional ``(dim, start, end, step)`` args with defaults.

        The defaults follow the aten schema of slice.Tensor / slice_scatter:
        ``dim=0, start=None, end=None, step=1``.
        """
        defaults = (0, None, None, 1)
        return tuple(args) + defaults[len(args):]

    def _covers_whole_dim(self, start, end, step,
                          dim_size: Union[int, SymInt]) -> bool:
        """Check whether ``start:end:step`` selects a whole dimension.

        Unlike reshape, a negative ``end`` in a slice counts from the back
        (``-1`` drops the last element), so it must never be treated as an
        inferred dimension here.
        """
        if step != 1:
            # Strided slices drop elements
            return False
        if not (start is None or start == 0):
            return False
        if end is None:
            # Omitted end means "until the end of the dimension"
            return True
        if isinstance(end, int) and end < 0:
            return False
        return bool(self.dims_equivalent(end, dim_size))

    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
                            i_dims: Iterable[Union[int, SymInt]]) -> bool:
        """
        Check that every reshape dimension is equivalent to the input's.

        :param dims: The dimension args to reshape
        :param i_dims: The corresponding dimensions of the input tensor
        :return: Are all dimension pairs equivalent?

        The dimensions are paired positionally with zip, so callers are
        expected to pass sequences of the same length.
        """
        return all(
            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal. Case 2 only
        has this meaning for reshape args: for slice bounds -1 is an
        index counted from the back, so negative bounds have to be
        rejected before calling this function.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.

        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim