"""Tests comparing ColoTensor op dispatch and lazy init against plain torch."""
from copy import deepcopy

import torch
from numpy import allclose, require  # noqa: F401  (`require` is unused in the visible code)

from colossalai.tensor import ColoTensor

def test_linear():
    """Check a torch.nn.Linear whose weight/bias have been swapped for ColoTensors.

    ``fc_ref`` is a deep copy of ``fc``, so both start from identical
    parameters.  ``fc``'s own weight and bias are wrapped in ColoTensors and
    re-installed on ``fc``.  Both modules then run the same input forward and
    backward, and the losses and the weight/bias gradients must agree.
    """
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    # Wrap fc's *own* parameters, not fc_ref's.  Wrapping fc_ref's parameters
    # made the module under test and the reference share parameters; if
    # init_from_torch_tensor keeps a reference (rather than a copy), both
    # backward passes accumulate into the same .grad and the gradient check
    # below compares a tensor with itself.
    sharded_weight = ColoTensor.init_from_torch_tensor(fc.weight)
    sharded_bias = ColoTensor.init_from_torch_tensor(fc.bias)

    # Replace the torch nn.Parameters with ColoTensors.
    delattr(fc, 'weight')
    setattr(fc, 'weight', sharded_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', sharded_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    assert loss_ref == loss
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
    assert allclose(fc_ref.bias.grad, fc.bias.torch_tensor().grad)


# Disabled: this test case currently fails.
# def test_uniform():
#     t = ColoTensor(torch.zeros(3, 5))
#     torch.nn.init.uniform_(t)
#     print(t)


def test_element_wise():
    """Element-wise / reduction ops on a ColoTensor must match plain torch.

    Covers torch.mean (exact equality) and the gelu and relu activations
    (compared with numpy.allclose).
    """
    plain = torch.randn(3, 5)
    colo = ColoTensor.init_from_torch_tensor(plain.clone())

    assert torch.mean(colo) == torch.mean(plain)
    for activation in (torch.nn.functional.gelu, torch.nn.functional.relu):
        assert allclose(activation(colo), activation(plain))


# Test a function not wrapped by
def test_no_wrap_op():
    t_ref = torch.randn(3, 5)
Jiarui Fang's avatar
Jiarui Fang committed
58
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
59
    assert torch.sum(t) == torch.sum(t_ref)
60
    assert torch.sum(input=t) == torch.sum(input=t_ref)
61

Jiarui Fang's avatar
Jiarui Fang committed
62
63
64
65
66
def test_lazy_init_tensor():
    """A ColoTensor built from a shape has no payload until torch_tensor() is called.

    After torch_tensor() is called, the returned tensor has 2 * 3 = 6 elements.
    """
    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
    # Identity check (PEP 8 E711): `== None` would route through the
    # operand's __eq__ rather than testing for the None singleton.
    assert lazy_t._torch_tensor is None
    assert lazy_t.torch_tensor().numel() == 6


def check_all():
    """Run every enabled test in this module, in definition order."""
    enabled_tests = (
        test_linear,
        test_element_wise,
        test_no_wrap_op,
        test_lazy_init_tensor,
    )
    for run_test in enabled_tests:
        run_test()


if __name__ == '__main__':
    # Allow running the suite directly as a script, without pytest.
    check_all()