from copy import deepcopy

import torch
from numpy import allclose

from colossalai.tensor import ColoTensor


def test_linear():
    """Compare a ``torch.nn.Linear`` backed by ColoTensors against a plain copy.

    ``fc`` has its weight and bias swapped for ColoTensor wrappers, while
    ``fc_ref`` is an untouched deep copy. After one forward/backward pass on
    identical inputs, the losses and the weight/bias gradients must agree.
    """
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    # Wrap fc's OWN parameters, never fc_ref's. ColoTensor keeps the tensor it
    # wraps (it is handed back by torch_tensor()), so wrapping fc_ref's
    # parameters would make both layers share storage: both backward passes
    # would accumulate into the same .grad and the gradient check below would
    # compare a tensor with itself.
    sharded_weight = ColoTensor(fc.weight)
    sharded_bias = ColoTensor(fc.bias)

    # Replace the torch nn.Parameters with ColoTensors.
    delattr(fc, 'weight')
    setattr(fc, 'weight', sharded_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', sharded_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    assert loss_ref == loss
    # Check both parameter gradients, not only the weight.
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
    assert allclose(fc_ref.bias.grad, fc.bias.torch_tensor().grad)


# Disabled because this test currently fails.
# def test_uniform():
#     t = ColoTensor(torch.zeros(3, 5))
#     torch.nn.init.uniform_(t)
#     print(t)


def test_element_wise():
    """Element-wise torch ops on a ColoTensor must match the plain-tensor results."""
    reference = torch.randn(3, 5)
    colo = ColoTensor(reference.clone())

    assert torch.mean(colo) == torch.mean(reference)

    # Activation functions are compared approximately.
    activations = (torch.nn.functional.gelu, torch.nn.functional.relu)
    for activation in activations:
        assert allclose(activation(colo), activation(reference))


# Exercise a torch function that, judging by the test name, has no dedicated
# ColoTensor wrapper; its result should still match the plain-tensor result.
def test_no_wrap_op():
    """``torch.sum`` on a ColoTensor must equal ``torch.sum`` on the source tensor."""
    base = torch.randn(3, 5)
    wrapped = ColoTensor(base.clone())
    expected = torch.sum(base)
    assert torch.sum(wrapped) == expected


if __name__ == '__main__':
    # Run every test in this module when executed directly (outside pytest).
    test_linear()
    test_element_wise()
    test_no_wrap_op()