"vllm/tool_parsers/granite_tool_parser.py" did not exist on "d6953beb91da4e9c99be4c0a1304a2d24189535c"
noop_elimination.py 5.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Union

import torch.fx
from torch import SymInt

from vllm.logger import init_logger

from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass

# Module-level logger; NoOpEliminationPass uses it to report (at debug level)
# how many no-op nodes were removed from the graph.
logger = init_logger(__name__)


class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, the torch internal no-op elimination pass
    does not handle certain slice variants.

    Example graph 1:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 2:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, 0, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]

    A slice/slice_scatter is only treated as a no-op when it starts at 0
    (or None), has step 1, and its end is None or provably reaches the end
    of the sliced dimension. In particular, a slice end of -1 is never
    accepted: as a slice bound it drops the last element, unlike the
    inferred -1 dimension of reshape.

    TODO(luka): This is currently tested in test_fusion,
     but separate tests could be good.
    """

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_noop_elimination")
        count = 0
        # Remove no-op reshapes/views and full-range slices/slice_scatters:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                src, shape = node.args[:2]
                src_shape = src.meta["val"].shape
                if len(shape) != len(src_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if self.all_dims_equivalent(shape, src_shape):
                    node.replace_all_uses_with(src)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice.Tensor):
                # Schema: slice(self, dim=0, start=None, end=None, step=1)
                src = node.args[0]
                dim_index, start, end, step = self._slice_args(node, 1)
                src_dim = src.meta["val"].shape[dim_index]

                if self._slice_covers_dim(start, end, step, src_dim):
                    node.replace_all_uses_with(src)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice_scatter.default):
                # Schema: slice_scatter(self, src, dim=0, start=None,
                #                       end=None, step=1)
                base, view = node.args[:2]
                dim_index, start, end, step = self._slice_args(node, 2)
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                view_dim = view_shape[dim_index]

                # Check that view fully covers base and the full view is used
                # (if the view fully covered the base after slicing but was not
                # fully used, we could replace slice_scatter with a simple slice
                # but that's a niche case).
                if (base_shape == view_shape and self._slice_covers_dim(
                        start, end, step, view_dim)):
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)
        self.dump_graph(graph, "after_noop_elimination")
        self.end_and_log()

    @staticmethod
    def _slice_args(node: torch.fx.Node, first: int) -> tuple:
        """
        Return the (dim, start, end, step) arguments of a slice-like node.

        Both aten.slice.Tensor and aten.slice_scatter.default take
        ``dim=0, start=None, end=None, step=1`` after their tensor
        argument(s). Positional arguments are read first, then kwargs.
        Arguments missing from both fall back to those schema defaults, so
        nodes with omitted trailing arguments no longer fail to unpack.
        :param node: The slice / slice_scatter node
        :param first: Positional index of the ``dim`` argument
                      (1 for slice, 2 for slice_scatter)
        :return: Tuple (dim, start, end, step)
        """
        names = ("dim", "start", "end", "step")
        defaults = (0, None, None, 1)
        values = []
        for offset, (name, default) in enumerate(zip(names, defaults)):
            idx = first + offset
            if idx < len(node.args):
                values.append(node.args[idx])
            else:
                values.append(node.kwargs.get(name, default))
        return tuple(values)

    @staticmethod
    def _slice_covers_dim(start, end, step,
                          i_dim: Union[int, SymInt]) -> bool:
        """
        Check whether slicing [start:end:step] selects a whole dimension.
        :param start: Slice start (int, None, or torch.fx.Node)
        :param end: Slice end (int, None, or torch.fx.Node)
        :param step: Slice step (int or torch.fx.Node)
        :param i_dim: Size of the sliced dimension in the input tensor
        :return: True only if every element of the dimension is selected

        Unlike dims_equivalent, -1 is NOT a wildcard here: a slice end of -1
        excludes the last element, so it never covers the full dimension.
        """
        # Only unit-step slices beginning at index 0 can be full-range.
        # A Node-valued start/step compares unequal (identity), so skip it.
        if step != 1 or not (start is None or start == 0):
            return False
        # end=None means "up to the end of the dimension".
        if end is None:
            return True
        # With a static size, any end >= size is clamped to size.
        if isinstance(end, int) and isinstance(i_dim, int):
            return end >= i_dim
        # Otherwise require an exact (possibly symbolic) match.
        if end == i_dim:
            return True
        # The end is a graph node whose traced value is the same SymInt.
        return isinstance(end, torch.fx.Node) and end.meta["val"] == i_dim

    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
                            i_dims: Iterable[Union[int, SymInt]]):
        """
        Check pairwise equivalence of reshape dims and input dims.

        The caller must have verified equal length, because zip() stops at
        the shorter iterable.
        """
        return all(
            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal (and the ranks
        match, which the reshape caller checks).

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.

        Because of case 2, this is only valid for reshape arguments; slice
        bounds must use _slice_covers_dim instead.
        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim