fx_utils.py 2.62 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import operator
5
from collections.abc import Iterable, Iterator
6
7
8

from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
9
from torch._ops import OpOverload, OpOverloadPacket
10
from torch.fx.node import Target
11
12


13
14
def is_func(node: fx.Node, target: Target) -> bool:
    """Return whether ``node`` is a ``call_function`` node calling ``target``.

    Args:
        node: The FX node to inspect.
        target: The callable (or other FX target) to compare against
            ``node.target``.

    Returns:
        ``True`` only when ``node.op`` is ``"call_function"`` and
        ``node.target == target``; ``False`` otherwise.
    """
    # Guard first: nodes of any other kind never match, whatever their target.
    if node.op != "call_function":
        return False
    return bool(node.target == target)
15
16


17
18
19
20
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """Return whether ``node`` is an ``auto_functionalized`` call wrapping ``op``.

    Args:
        node: The FX node to inspect.
        op: The op overload expected as the node's first positional argument.

    Returns:
        ``False`` if ``node`` is not a ``call_function`` node targeting
        ``auto_functionalized``; otherwise the result of comparing
        ``node.args[0]`` with ``op``.
    """
    if not is_func(node, auto_functionalized):
        return False
    # The wrapped op is carried as the first positional argument.
    return node.args[0] == op


21
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
    """Return the first ``auto_functionalized`` node wrapping ``op``, if any.

    Args:
        nodes: The FX nodes to scan, in iteration order.
        op: The op overload to look for as the wrapped op.

    Returns:
        The first node for which ``is_auto_func(node, op)`` holds, or ``None``
        when no node in ``nodes`` matches.
    """
    # Reuse the shared predicate rather than re-spelling it inline; ``next``
    # stops at the first match, exactly like an early ``return`` in a loop.
    return next((node for node in nodes if is_auto_func(node, op)), None)


def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """Return the first ``auto_functionalized`` node wrapping ``op``.

    Strict counterpart of ``find_auto_fn_maybe``: the node must exist.

    Args:
        nodes: The FX nodes to scan, in iteration order.
        op: The op overload to look for as the wrapped op.

    Returns:
        The first matching node.

    Raises:
        AssertionError: If no matching node is found.
    """
    match = find_auto_fn_maybe(nodes, op)
    assert match is not None, f"Could not find {op} in nodes {nodes}"
    return match


def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
    """Return the ``getitem`` user extracting element ``idx`` from ``node``, if any.

    Args:
        node: The FX node whose users are scanned.
        idx: The index expected as the second positional argument of the
            ``operator.getitem`` call.

    Returns:
        The first user of ``node`` (in ``node.users`` iteration order) that is
        an ``operator.getitem`` call with ``args[1] == idx``, or ``None`` when
        there is no such user.
    """
    candidates = (
        consumer
        for consumer in node.users
        if is_func(consumer, operator.getitem) and consumer.args[1] == idx
    )
    return next(candidates, None)


def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """Return the ``getitem`` user extracting element ``idx`` from ``node``.

    Strict counterpart of ``find_getitem_maybe``: the user must exist.

    Args:
        node: The FX node whose users are scanned.
        idx: The index the ``getitem`` call must extract.

    Returns:
        The matching ``getitem`` node.

    Raises:
        AssertionError: If ``node`` has no such ``getitem`` user.
    """
    found = find_getitem_maybe(node, idx)
    assert found is not None, f"Could not find getitem {idx} in node {node}"
    return found
50
51
52


# An auto-functionalization-aware utility for finding nodes with a specific op
53
54
55
56
57
58
59
60
61
62
63
# Also handles op overload packets and finds all overloads
def find_op_nodes(
    op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
    if isinstance(op, OpOverloadPacket):
        for overload in op.overloads():
            overload_op = getattr(op, overload)
            yield from find_op_nodes(overload_op, graph)
        return

    assert isinstance(op, OpOverload)
64
65

    yield from graph.find_nodes(op="call_function", target=op)
66
67
68
69
70
71
72
73
74
75
76
77

    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
        if n.args[0] == op:
            yield n


def get_only_user(node: fx.Node) -> fx.Node:
    """Return the single user of ``node``, asserting that there is exactly one.

    Note that even if a node has only one user, it might share storage with
    another node, which might need to be taken into account by the caller.

    Args:
        node: The FX node whose sole user is requested.

    Returns:
        The one node in ``node.users``.

    Raises:
        AssertionError: If ``node`` does not have exactly one user.
    """
    users = node.users
    # Include the node and its actual users so a failure is diagnosable.
    assert len(users) == 1, (
        f"Expected exactly one user of node {node}, "
        f"found {len(users)}: {list(users)}"
    )
    return next(iter(users))