# test_coloproxy.py
import pytest
import torch

from colossalai.fx.proxy import ColoProxy


# NOTE(review): skipped unconditionally with no reason given; consider
# pytest.mark.skip(reason=...) documenting why this test is disabled.
@pytest.mark.skip
def test_coloproxy():
    """Check that ColoProxy reports tensor properties from its meta data.

    Builds a ColoProxy around a real torch.fx node, attaches a (4, 2) float32
    meta tensor, and asserts that len/shape/dim/dtype/size are answered from
    that meta tensor.
    """
    # create a dummy node only for testing purpose; any node of a traced
    # graph will do (the first one is presumably the `input` placeholder)
    model = torch.nn.Linear(10, 10)
    gm = torch.fx.symbolic_trace(model)
    node = next(iter(gm.graph.nodes))

    # create proxy
    proxy = ColoProxy(node=node)
    # Pin the dtype explicitly: torch.empty otherwise uses the global default
    # dtype, which another test could change via torch.set_default_dtype and
    # break the float32 assertion below.
    proxy.meta_data = torch.empty(4, 2, dtype=torch.float32, device='meta')

    assert len(proxy) == 4
    # Separate asserts so a failure pinpoints which axis is wrong.
    assert proxy.shape[0] == 4
    assert proxy.shape[1] == 2
    assert proxy.dim() == 2
    assert proxy.dtype == torch.float32
    assert proxy.size(0) == 4


if __name__ == '__main__':
    # Direct invocation runs the test body, bypassing pytest's skip marker.
    test_coloproxy()