# test_coloproxy.py
import torch
from colossalai.fx.proxy import ColoProxy


def test_coloproxy():
    """Check that ``ColoProxy`` answers tensor queries from its ``meta_data``.

    A proxy is built around a dummy node taken from a traced ``nn.Linear``.
    It is then given a ``(4, 2)`` meta-device tensor as ``meta_data``.
    After that, ``len``, ``shape``, ``dim()``, ``dtype`` and ``size(0)`` on
    the proxy must agree with that tensor.
    """
    # The proxy only needs some node to wrap. Take the first node of a
    # freshly traced graph as a dummy; its identity is irrelevant here.
    model = torch.nn.Linear(10, 10)
    gm = torch.fx.symbolic_trace(model)
    node = next(iter(gm.graph.nodes))

    # Wrap the node and attach a meta tensor. Tensors on the 'meta' device
    # carry shape/dtype information without allocating real storage.
    proxy = ColoProxy(node=node)
    proxy.meta_data = torch.empty(4, 2, device='meta')

    assert len(proxy) == 4
    # Separate asserts so a failure pinpoints the offending dimension.
    assert proxy.shape[0] == 4
    assert proxy.shape[1] == 2
    assert proxy.dim() == 2
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4


if __name__ == '__main__':
    # Executing the module as a script runs the check once without pytest;
    # a failing assertion surfaces as an AssertionError traceback.
    test_coloproxy()