dataflow.py
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List

from torch.fx import Graph, Node

from .._compatibility import compatibility
from .memory_utils import activation_size, is_inplace


class Phase(Enum):
    FORWARD = 0
    BACKWARD = 1
    PLACEHOLDER = 2


@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
    """
    GraphInfo is a dataclass for MetaInfo, which measures
    the execution memory cost and FLOPs with `MetaTensor`.
    The dataflow analysis is conducted on a single node of the FX graph.
    ============================================================================
                            -------------------------------
                            |            Node             |
    [fwd_in] are       ---> | [fwd_in]          [bwd_out] |    <----- [bwd_out] marks the memory for `grad_out`.
    placeholders saved for  |     | \__________     |     |
    backward.               |     |            \    |     |
                            | [fwd_tmp] ------> [bwd_tmp] |    <-----
                            |     |  \_________     |     |    [bwd_tmp] marks the peak memory
                            |    / \           \    |     |    in backward pass.
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
    in [fwd_tmp] because    |          |  \_____    |     |
    it is not saved for     |          |        \   |     |
    backward.               |      [fwd_out]     \  |     |    <----- [fwd_out] is [fwd_in] for the next node.
                            -------------------------------
    ============================================================================
    Attributes:
        fwd_flop (int): The forward FLOPs of a certain node.
        fwd_time (float): The real forward time (s) of a certain node.
        bwd_flop (int): The backward FLOPs of a certain node.
        bwd_time (float): The real backward time (s) of a certain node.
        save_fwd_in (bool): The decision variable of whether to save the `fwd_mem_out` of parent nodes.
        fwd_in (List): See the above illustration.
        fwd_tmp (List): See the above illustration.
        fwd_out (List): See the above illustration.
        fwd_mem_tmp (int): See the above illustration.
        fwd_mem_out (int): See the above illustration.
        bwd_mem_tmp (int): See the above illustration.
        bwd_mem_out (int): See the above illustration.
    """

    # TODO(super-dainiu): remove redundant items; currently all of them are necessary for development.

    fwd_flop: int = 0
    fwd_time: float = 0.0
    bwd_flop: int = 0
    bwd_time: float = 0.0
    save_fwd_in: bool = False
    fwd_in: List = field(default_factory=list)
    fwd_tmp: List = field(default_factory=list)
    fwd_out: List = field(default_factory=list)
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0
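

# Illustrative helper (hypothetical name, not part of the original API):
# `activation_size`, imported above and applied to the same tensor lists in
# `autograd_graph_analysis` below, can be used directly on the lists tracked by
# `GraphInfo` to estimate how many bytes a node keeps alive for the backward pass.
def _saved_activation_bytes(info: GraphInfo) -> int:
    # Placeholders saved for backward plus forward temporaries saved for backward.
    return activation_size(info.fwd_in) + activation_size(info.fwd_tmp)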


def is_phase(n: Node, phase: Phase) -> bool:
    assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
    return n.meta["phase"] == phase


@compatibility(is_backward_compatible=False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    """Analyze the autograd node dependencies and find out the memory usage.
    The input graph should have every node marked with the keyword `phase` in
    `node.meta`, together with a `saved_tensor` entry listing the tensors the
    node keeps alive.
    ============================================================================
    Placeholder ---->   p           o     <---- We need to keep track of grad out
                        |\________  |
                        ↓         ↘|
                        f --------> b
                        |\ \_____   ↑
                        | \      ↘ /
                        f  f ----> b      <---- Not every forward result needs to be saved for backward
                        |   \____  ↑
                         ↘      ↘|
                           f ----> b      <---- Backward results can be freed as soon as they are no longer required
                             ↘ ↗
                               l
    ============================================================================
    Args:
        graph (Graph): The autograd graph with nodes marked for keyword `phase`.

    Returns:
        graph_info (GraphInfo): Meta information for the dataflow.
    """

    def _peak_memory(deps: Dict[Node, int]):
        peak_mem = 0
        for k, v in deps.items():
            # A backward node with pending users (v > 0) still holds its saved tensors,
            # unless the node and all of its users are in-place.
            if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
                peak_mem += activation_size(k.meta["saved_tensor"])
            # A forward node whose dependency counter was dropped to `-inf` below has been
            # fully consumed, so its saved tensors are no longer counted as resident.
            if v <= float("-inf") and is_phase(k, Phase.FORWARD):
                peak_mem -= activation_size(k.meta["saved_tensor"])
        return peak_mem

    # deps is used to track all the memory dependencies of the graph.
    deps = {}
    graph_info = GraphInfo()

    for n in graph.nodes:
        n: Node
        deps[n] = len(n.users)
        # A forward tensor that is marked `save` but is also an input to
        # `Phase.FORWARD` should be saved during forward.
        # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
        # Any `fwd_mem_in` should be kept in memory even if this function
        # is checkpointed.
        # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
        # the node, `fwd_mem_tmp` can be freed.
        if is_phase(n, Phase.PLACEHOLDER):
            graph_info.fwd_in += n.meta["saved_tensor"]
        if is_phase(n, Phase.FORWARD):
            graph_info.fwd_tmp += n.meta["saved_tensor"]
        elif is_phase(n, Phase.BACKWARD):
            if len(n.users):
                graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
            else:
                # TODO: some of the bwd_mem_out might be model parameters.
                # A backward node without users is basically a `grad_out` node.
                graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
        # Decrement the dependency counter of every input; once an input has no
        # remaining users, mark it with `-inf` so `_peak_memory` treats it as freed.
        for input_n in n.all_input_nodes:
            if input_n in deps:
                deps[input_n] -= 1
                if deps[input_n] <= 0:
                    deps[input_n] = float("-inf")
    return graph_info
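

# Usage sketch (illustrative only): `MyModel` is a hypothetical nn.Module, and the
# per-node annotations below would normally be produced by a profiling pass that
# also records real `saved_tensor` lists; empty lists keep the sketch minimal.
#
#   import torch
#   from torch.fx import symbolic_trace
#
#   traced = symbolic_trace(MyModel())
#   for node in traced.graph.nodes:
#       node.meta["phase"] = Phase.PLACEHOLDER if node.op == "placeholder" else Phase.FORWARD
#       node.meta["saved_tensor"] = []
#   info = autograd_graph_analysis(traced.graph)
#   print(info.bwd_mem_tmp, info.bwd_mem_out)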