reshapes.py 2.92 KB
Newer Older
1
2
3
4
5
6
7
from typing import Union

import torch.fx
from torch import SymInt

from vllm.logger import init_logger

8
9
from .vllm_inductor_pass import VllmInductorPass, is_func

10
11
12
logger = init_logger(__name__)


13
class RedundantReshapesPass(VllmInductorPass):
    """
    Inductor pass that strips reshape nodes which do not change the shape
    of their input.

    RMSNorm-quant fusion relies on this pass: apply_fp8_linear inserts a
    reshape that is a no-op when the tensor is already 2D, and that extra
    node has to be removed for the fusion to work properly.

    Example graph:

    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]
    """

    def __call__(self, graph: torch.fx.Graph):
        """
        Remove every no-op ``aten.reshape.default`` node from ``graph``.

        A reshape is treated as a no-op when its target shape has the same
        rank as the input, contains at most one ``-1``, and every target
        dimension is equivalent to the matching input dimension
        (see ``dims_equivalent``). Users of such a node are rewired to the
        reshape's input and the node is erased.

        :param graph: The FX graph to rewrite in place
        """
        self.begin()
        self.dump_graph(graph, "before_reshapes")

        removed = 0
        for node in graph.nodes:
            # Only plain reshape calls are candidates.
            if not is_func(node, torch.ops.aten.reshape.default):
                continue

            source, target_shape = node.args[:2]
            source_shape = source.meta["val"].shape

            # A reshape that changes rank can never be a no-op.
            if len(target_shape) != len(source_shape):
                continue

            # More than one inferred (-1) dimension is invalid, leave as is.
            if target_shape.count(-1) > 1:
                continue

            # Compare dimension by dimension, stopping at the first mismatch.
            is_noop = True
            for wanted, actual in zip(target_shape, source_shape):
                if not self.dims_equivalent(wanted, actual):
                    is_noop = False
                    break
            if not is_noop:
                continue

            node.replace_all_uses_with(source)
            graph.erase_node(node)
            removed += 1

        logger.debug("Removed %s no-op reshapes", removed)

        self.dump_graph(graph, "after_reshapes")
        self.end_and_log()

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        Decide whether a reshape dimension matches an input dimension.

        :param dim: The dimension arg to reshape
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        The dimensions are considered equivalent when any of these holds:
        1. They compare equal (both plain integers)
        2. The reshape dimension is -1, i.e. it is inferred
        3. The reshape dimension is a torch.fx.Node whose ``meta["val"]``
           (a SymInt) equals the input dimension

        An inferred dimension (case 2) is not equal to the input dimension
        by itself, but it has to be once every other dimension matches,
        which is what the caller verifies.
        """
        # Case 1: direct equality.
        if dim == i_dim:
            return True
        # Case 2: inferred dimension.
        if dim == -1:
            return True
        # Case 3: symbolic dimension carried by a graph node.
        if not isinstance(dim, torch.fx.Node):
            return False
        return dim.meta["val"] == i_dim