test_liveness_analysis.py 1.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch.nn as nn
import torch
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.fx import ColoTracer, ColoGraphModule


class LinearModel(nn.Module):
    """Small two-layer network used as the tracing target in this file.

    The forward pass doubles the first input, feeds it through
    ``linear1 -> relu -> linear2`` and finally adds the second input.
    The ReLU submodule is constructed with ``inplace=True``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names matter: the test in this file looks up a live
        # variable called `relu`, so the submodules keep these exact names.
        self.linear1 = nn.Linear(4, 4)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x1, x2):
        """Return ``linear2(relu(linear1(x1 * 2))) + x2``."""
        scaled = x1 * 2
        hidden = self.linear1(scaled)
        activated = self.relu(hidden)
        projected = self.linear2(activated)
        return projected + x2


def test_liveness_analysis():
    """Trace ``LinearModel`` on meta tensors and check the liveness analysis.

    Checks the number of reported stages, that stage 5 holds an in-place
    live variable named ``relu``, and that the last stage has exactly one
    more entry in ``all_live_vars`` than in ``unique_live_vars``.
    """
    net = LinearModel()
    # Both forward arguments are traced with 4x4 tensors on the meta device.
    meta_inputs = {name: torch.rand(4, 4, device='meta') for name in ('x1', 'x2')}
    traced_graph = ColoTracer().trace(net, meta_args=meta_inputs)
    gm = ColoGraphModule(root=net, graph=traced_graph, class_name=type(net).__name__)

    stages = GraphAnalyser(gm).liveness_analysis()

    # 8 stages are expected, counting the input and output ones
    assert len(stages) == 8

    # stage 5 must expose a live variable called `relu`,
    # and that variable must be flagged as in-place
    live_at_relu = stages[5].all_live_vars
    assert live_at_relu.exists('relu')
    assert live_at_relu.get('relu').is_inplace

    # because of the in-place op, the last stage lists exactly one more
    # variable in `all_live_vars` than in `unique_live_vars`
    final_stage = stages[7]
    total_count = len(final_stage.all_live_vars)
    unique_count = len(final_stage.unique_live_vars)
    assert total_count - unique_count == 1


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_liveness_analysis()