from numpy import allclose
import torch

from colossalai.tensor import ColoTensor
from copy import deepcopy
from colossalai.utils import get_current_device


def test_layernorm():
    """Check that a LayerNorm whose weight is wrapped in a ColoTensor
    produces the same forward output and weight gradient as plain torch.
    """
    device = get_current_device()
    # NOTE(review): positional args make this normalized_shape=2, eps=3.
    # eps=3 looks unintentional, but both modules share it so the
    # comparison is still valid — confirm whether the default eps was meant.
    ln_torch = torch.nn.LayerNorm(2, 3, device=device)
    ln_colo = deepcopy(ln_torch)

    data = torch.randn(3, 2, device=device)
    data_colo = ColoTensor.init_from_torch_tensor(tensor=data.clone().detach())

    # Swap the torch weight parameter for a ColoTensor wrapper.
    delattr(ln_colo, 'weight')
    wrapped_weight = ln_torch.weight.clone().detach()
    wrapped_weight.requires_grad = True
    setattr(ln_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=wrapped_weight))

    out_torch = ln_torch(data)
    out_colo = ln_colo(data_colo)

    assert allclose(out_colo.torch_tensor().detach().cpu(), out_torch.detach().cpu())

    torch.mean(out_torch).backward()
    torch.mean(out_colo).backward()

    assert allclose(ln_torch.weight.grad.cpu(), ln_colo.weight.torch_tensor().grad.cpu())


def test_linear():
    """Verify nn.Linear still works after its parameters are replaced with
    ColoTensor wrappers: loss and gradients must match a plain torch copy.
    """
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

    # replace the torch nn.Parameters with ColoTensor wrappers
    delattr(fc, 'weight')
    setattr(fc, 'weight', sharded_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', sharded_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    assert (loss_ref == loss)
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
    # The bias was wrapped too, so its gradient must match as well
    # (previously unchecked).
    assert allclose(fc_ref.bias.grad, fc.bias.torch_tensor().grad)


# The test case failed
# def test_uniform():
#     t = ColoTensor(torch.zeros(3, 5))
#     torch.nn.init.uniform_(t)
#     print(t)

def test_element_wise():
    """Element-wise ops applied to a ColoTensor must match plain torch."""
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
    assert torch.mean(t) == torch.mean(t_ref)
    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))


# Test a function that is not wrapped by the ColoTensor op dispatcher.
def test_no_wrap_op():
    """torch.sum must still accept a ColoTensor, both positionally and by
    keyword, and agree with the plain-tensor result."""
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
    assert torch.sum(t) == torch.sum(t_ref)
    assert torch.sum(input=t) == torch.sum(input=t_ref)


def test_lazy_init_tensor():
    """A ColoTensor built from a shape alone is lazy: the backing torch
    tensor starts empty and is presumably materialized on first access —
    afterwards numel() matches the declared 2x3 shape."""
    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
    # Not materialized yet: internal storage holds no elements.
    assert lazy_t._torch_tensor.numel() == 0
    # Accessing numel()/torch_tensor() yields the full 2*3 = 6 elements.
    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()

def check_all():
    """Run the tests that do not need a device; test_layernorm (which
    calls get_current_device) is invoked separately under __main__."""
    test_linear()
    test_element_wise()
    test_no_wrap_op()
    test_lazy_init_tensor()

if __name__ == '__main__':
    # test_lazy_init_ptensor()
    test_layernorm()