test_op.py 2.09 KB
Newer Older
Jiarui Fang's avatar
Jiarui Fang committed
1
from numpy import allclose, require
2
import torch
Jiarui Fang's avatar
Jiarui Fang committed
3
from colossalai.tensor import ColoTensor
4
5
from copy import deepcopy

6

7
8
9
10
11
12
13
14
15
16
def test_linear():
    """Check that nn.Linear gives the same loss and gradients with ColoTensor parameters.

    ``fc`` has its weight and bias replaced by ColoTensor wrappers. ``fc_ref`` is
    a deep copy taken before the swap, so it keeps plain torch parameters. Both
    modules run forward and backward on identical inputs, and the test compares
    the resulting loss and parameter gradients.
    """
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    # Wrap fc's OWN parameters, not fc_ref's. If the ColoTensor keeps a
    # reference to the tensor it was built from, wrapping fc_ref's parameters
    # would make both modules share storage and gradients. The gradient
    # comparison below would then compare a tensor with itself.
    colo_weight = ColoTensor.init_from_torch_tensor(fc.weight)
    colo_bias = ColoTensor.init_from_torch_tensor(fc.bias)

    # Replace the torch nn.Parameters with ColoTensors.
    delattr(fc, 'weight')
    setattr(fc, 'weight', colo_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', colo_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    assert (loss_ref == loss)
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
    assert allclose(fc_ref.bias.grad, fc.bias.torch_tensor().grad)


# The test case failed
# def test_uniform():
Jiarui Fang's avatar
Jiarui Fang committed
44
#     t = ColoTensor(torch.zeros(3, 5))
45
46
47
#     torch.nn.init.uniform_(t)
#     print(t)

48

49
50
def test_element_wise():
    """Compare element-wise ops on a ColoTensor with the same ops on a plain tensor."""
    reference = torch.randn(3, 5)
    colo = ColoTensor.init_from_torch_tensor(reference.clone())

    assert torch.mean(colo) == torch.mean(reference)
    for activation in (torch.nn.functional.gelu, torch.nn.functional.relu):
        assert allclose(activation(colo), activation(reference))


57
58
59
# Test an op that ColoTensor does not wrap (torch.sum), called positionally and by keyword.
def test_no_wrap_op():
    """Check torch.sum on a ColoTensor, calling it both positionally and by keyword."""
    reference = torch.randn(3, 5)
    colo = ColoTensor.init_from_torch_tensor(reference.clone())

    positional_ok = torch.sum(colo) == torch.sum(reference)
    assert positional_ok
    keyword_ok = torch.sum(input=colo) == torch.sum(input=reference)
    assert keyword_ok
63

64

Jiarui Fang's avatar
Jiarui Fang committed
65
def test_lazy_init_tensor():
    """Check that a ColoTensor built from a shape starts with an empty backing tensor."""
    rows, cols = 2, 3
    lazy = ColoTensor(rows, cols, dtype=torch.float32, requires_grad=True)

    # Right after construction the private backing tensor holds no elements.
    assert lazy._torch_tensor.numel() == 0

    expected = rows * cols
    assert lazy.numel() == expected
    # Asking for the payload yields a tensor with the full element count.
    assert lazy.torch_tensor().numel() == expected

Jiarui Fang's avatar
Jiarui Fang committed
70

Ziyue Jiang's avatar
Ziyue Jiang committed
71
72
73
def check_all():
    """Run the tests listed below, in order, stopping at the first failure."""
    for test_fn in (
        test_linear,
        test_element_wise,
        test_no_wrap_op,
        test_lazy_init_tensor,
    ):
        test_fn()

77

Ziyue Jiang's avatar
Ziyue Jiang committed
78
if __name__ == '__main__':
    # Run the whole suite when executed as a script, not just a single test.
    check_all()