fx_utils.py 2.86 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import operator
5
from collections.abc import Iterable, Iterator
6
from typing import Optional
7
8
9
10
11
12
13
14
15
16

from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload


def is_func(node: fx.Node, target) -> bool:
    """Return True when *node* is a ``call_function`` node whose target is *target*.

    The target is compared with ``==``; nodes of any other kind
    (placeholder, output, call_method, ...) never match.
    """
    if node.op != "call_function":
        return False
    return node.target == target


17
18
19
20
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """Return True when *node* is an ``auto_functionalized`` call wrapping *op*.

    The wrapped operator is taken from the node's first positional argument.
    """
    if not is_func(node, auto_functionalized):
        return False
    return node.args[0] == op


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                            op: OpOverload) -> Optional[fx.Node]:
    """Return the first node in *nodes* whose target equals *op*, else None.

    Only ``node.target`` is compared; the node kind is not checked.
    Iteration over *nodes* stops at the first match.
    """
    return next((candidate for candidate in nodes if candidate.target == op),
                None)


def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first node in *nodes* whose target equals *op*.

    Raises:
        AssertionError: if no node matches. This is an ``assert``, so under
            ``python -O`` the check is skipped and None is returned.
    """
    found = next((candidate for candidate in nodes if candidate.target == op),
                 None)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
                       op: OpOverload) -> Optional[fx.Node]:
    """Return the first ``auto_functionalized`` node in *nodes* wrapping *op*.

    Returns None when no such node exists. Iteration over *nodes* stops at
    the first match.
    """
    for node in nodes:
        # Delegate to is_auto_func so the "auto_functionalized call wrapping
        # op" rule is defined in exactly one place.
        if is_auto_func(node, op):
            return node
    return None


def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first ``auto_functionalized`` node in *nodes* wrapping *op*.

    Raises:
        AssertionError: if no such node exists. This is an ``assert``, so
            under ``python -O`` the check is skipped and None is returned.
    """
    match = next((n for n in nodes if is_auto_func(n, op)), None)
    assert match is not None, f"Could not find {op} in nodes {nodes}"
    return match


def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    """Return the ``operator.getitem`` user of *node* that reads index *idx*.

    Users are scanned in ``node.users`` order and the first match wins;
    None is returned when no user extracts that index.
    """
    return next(
        (user for user in node.users
         if is_func(user, operator.getitem) and user.args[1] == idx),
        None,
    )


def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """Return the getitem user of *node* that extracts element *idx*.

    The search itself is done by ``find_getitem_maybe``.

    Raises:
        AssertionError: if no such user exists. This is an ``assert``, so
            under ``python -O`` the check is skipped and None is returned.
    """
    getitem_node = find_getitem_maybe(node, idx)
    assert getitem_node is not None, (
        f"Could not find getitem {idx} in node {node}")
    return getitem_node
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84


# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
    if not op._schema.is_mutable:
        yield from graph.find_nodes(op="call_function", target=op)

    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
        if n.args[0] == op:
            yield n


def get_only_user(node: fx.Node) -> fx.Node:
    """Return the single user of *node*, asserting that there is exactly one.

    Even if a node has only one user, it may share storage with another
    node, which callers might need to take into account.

    Raises:
        AssertionError: if *node* does not have exactly one user. The message
            reports the actual user count and users to ease debugging.
    """
    assert len(node.users) == 1, (
        f"Expected exactly 1 user of node {node}, found {len(node.users)}: "
        f"{list(node.users)}")
    return next(iter(node.users))