# learned_nonlin_test.py: tests for learned_nonlin from torch_learned_nonlin.
import random
import torch
from torch_learned_nonlin import learned_nonlin


def test_learned_nonlin_basic():
    """Smoke-test learned_nonlin on a small, deterministic input.

    For float32 and float64, builds ``x`` of shape (B, C, T) whose values
    along the last axis start at -2.0 and increase in steps of 0.4 (the same
    ramp for every (b, c)).  It also builds ``params`` of shape (C, N + 1)
    with ``params[c, i] == i + c`` and N = 2 * K.  It then prints the inputs
    and ``learned_nonlin(x, params, dim=1)``.

    Nothing is asserted; the printed output is meant for manual inspection.
    """
    for dtype in [torch.float32, torch.float64]:
        B = 2
        C = 4
        T = 10
        # Ramp of T values starting at -2.0, broadcast to every (b, c) pair.
        x = -2.0 + 0.4 * torch.arange(T, dtype=dtype)
        x = x.reshape(1, 1, T).repeat(B, C, 1)

        K = 4
        N = K * 2
        # (1, N + 1) + (C, 1) broadcasts to (C, N + 1): row c is c, c+1, ...
        params = (torch.arange(N + 1, dtype=dtype).unsqueeze(0)
                  + torch.arange(C, dtype=dtype).unsqueeze(1))
        print("x = ", x)
        print("params = ", params)
        print("x.shape = ", x.shape)
        y = learned_nonlin(x, params, dim=1)
        print("y = ", y)




def test_learned_nonlin_zeros():
    """Check that all-zero input with pos_add=0 and pos_mul=1 gives zero output.

    Runs on CPU and, when CUDA is available, on cuda:0, in float32 and
    float64.  The input has 2 * C channels and the expected output has C
    channels.  After the forward check, it back-propagates ``output.sum()``
    and prints the resulting gradients; the gradients are not asserted.

    NOTE(review): this calls ``learned_nonlin(input, pos_add, pos_mul)``,
    which differs from the ``learned_nonlin(x, params, dim=...)`` call in
    test_learned_nonlin_basic, and ``__main__`` does not run it. Confirm the
    current signature before re-enabling.
    """
    N = 1
    C = 2
    H = 3
    W = 4
    for device in [torch.device('cpu'), torch.device('cuda:0')]:
        if device == torch.device('cuda:0') and not torch.cuda.is_available():
            print("Warning: CUDA not available, not testing this part.")
            continue
        for dtype in [torch.float32, torch.float64]:
            print("device=", device, ", dtype=", dtype)
            x = torch.zeros(N, 2 * C, H, W, device=device, dtype=dtype)
            kH = 5
            kW = 5
            pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
            pos_mul = torch.ones(C, kH, kW, device=device, dtype=dtype)
            x.requires_grad = True
            pos_add.requires_grad = True
            pos_mul.requires_grad = True

            output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
            output = learned_nonlin(x, pos_add, pos_mul)
            assert torch.allclose(output, output_ref)

            output.sum().backward()
            print("input_grad=", x.grad)
            print("pos_add_grad=", pos_add.grad)
            print("pos_mul_grad=", pos_mul.grad)


def test_learned_nonlin_compare():
    """Compare CPU and CUDA results of learned_nonlin on identical random data.

    Skipped when CUDA is unavailable.  For float32 and float64, it evaluates
    the op on CPU and cuda:0 copies of the same random tensors.  It asserts
    that the forward outputs agree (``torch.allclose`` with atol=1e-5).  It
    then back-propagates the same random output gradient through both.

    For each of pos_add, pos_mul and the input, it reports an error when the
    summed absolute CPU-vs-CUDA gradient difference exceeds 1e-5 times the
    summed absolute value of (cpu_grad + cuda_grad).  Gradient mismatches are
    only printed, not asserted.

    NOTE(review): this calls ``learned_nonlin(input, pos_add, pos_mul)``,
    which differs from the ``learned_nonlin(x, params, dim=...)`` call in
    test_learned_nonlin_basic, and ``__main__`` does not run it. Confirm the
    current signature before re-enabling.
    """
    N = 1
    C = 2
    H = 3
    W = 4
    if not torch.cuda.is_available():
        print("Warning: CUDA not available, not testing this part.")
        return
    cpu = torch.device('cpu')
    device = torch.device('cuda:0')
    for dtype in [torch.float32, torch.float64]:
        print("dtype=", dtype)
        x = torch.randn(N, 2 * C, H, W, dtype=dtype)
        x_cuda = x.to(device).detach()

        kH = 5
        kW = 5
        pos_add = torch.randn(C, kH, kW, dtype=dtype)
        pos_mul = torch.randn(C, kH, kW, dtype=dtype)
        pos_add_cuda = pos_add.to(device).detach()
        pos_mul_cuda = pos_mul.to(device).detach()

        # CPU and CUDA copies are separate leaves, so each collects its own .grad.
        for t in [pos_add, pos_mul, pos_add_cuda, pos_mul_cuda, x, x_cuda]:
            t.requires_grad = True

        output = learned_nonlin(x, pos_add, pos_mul)
        output_cuda = learned_nonlin(x_cuda, pos_add_cuda, pos_mul_cuda)
        print("output = ", output)
        print("output_cuda = ", output_cuda)

        # Same upstream gradient for both devices so the grads are comparable.
        output_grad = torch.randn(*output.shape, dtype=dtype)
        output.backward(gradient=output_grad)
        output_cuda.backward(gradient=output_grad.to(device))

        output_cuda_cpu = output_cuda.to(cpu)
        diff = (output - output_cuda_cpu).abs().sum()
        sum_abs = output.abs().sum()
        print("Diff = ", diff, ", abs = ", sum_abs)
        assert torch.allclose(output, output_cuda_cpu, atol=1.0e-05)

        for a, b, name in [(pos_add, pos_add_cuda, 'pos_add'),
                           (pos_mul, pos_mul_cuda, 'pos_mul'),
                           (x, x_cuda, 'input')]:
            grad = a.grad
            cuda_grad = b.grad.to(cpu)
            diff_abs = (grad - cuda_grad).abs().sum().item()
            grad_sum_abs = (grad + cuda_grad).abs().sum().item()
            print(f"Comparing grad of {name}: diff={diff_abs}, sum={grad_sum_abs}")
            # Reported, not asserted.
            if diff_abs > 1.0e-05 * grad_sum_abs:
                print(f"Error: too much difference in grad of {name}.")
                print("grad = ", grad)
                print("cuda_grad = ", cuda_grad)



def test_learned_nonlin_rand_compare():
    """Compare CPU and CUDA forward outputs of learned_nonlin on random sizes.

    Skipped entirely when CUDA is unavailable.  Otherwise it runs 30 trials.
    Each trial picks a random problem size (the largest of N, C, H, W is
    halved until N * C * H * W <= 65535) and random odd kernel sizes kH, kW
    in [1, 11].  It evaluates the op on identical CPU and cuda:0 inputs in
    float32 and float64.  The test fails when the summed absolute output
    difference exceeds 0.1% of the summed absolute CPU output.

    NOTE(review): this calls ``learned_nonlin(input, pos_add, pos_mul)``,
    which differs from the ``learned_nonlin(x, params, dim=...)`` call in
    test_learned_nonlin_basic, and ``__main__`` does not run it. Confirm the
    current signature before re-enabling.
    """
    # Guard once up front instead of after sampling sizes inside the loop.
    if not torch.cuda.is_available():
        print("Warning: CUDA not available, not testing this part.")
        return
    cpu = torch.device('cpu')
    device = torch.device('cuda:0')
    for _ in range(30):
        N = random.randint(1, 256)
        C = random.randint(1, 64)
        H = random.randint(1, 128)
        W = random.randint(1, 128)

        # Halve the largest dimension (ties: N, then C, then H) until small enough.
        while N * C * H * W > 65535:
            if N >= C and N >= H and N >= W:
                N = N // 2
            elif C >= H and C >= W:
                C = C // 2
            elif H >= W:
                H = H // 2
            else:
                W = W // 2

        for dtype in [torch.float32, torch.float64]:
            print("dtype=", dtype)
            x = torch.randn(N, 2 * C, H, W, dtype=dtype)
            x_cuda = x.to(device)

            # Odd kernel sizes in {1, 3, 5, 7, 9, 11}.
            kH = random.randint(1, 10)
            kW = random.randint(1, 10)
            if kH % 2 == 0:
                kH += 1
            if kW % 2 == 0:
                kW += 1
            pos_add = torch.randn(C, kH, kW, dtype=dtype)
            pos_mul = torch.randn(C, kH, kW, dtype=dtype)
            pos_add_cuda = pos_add.to(device)
            pos_mul_cuda = pos_mul.to(device)

            output = learned_nonlin(x, pos_add, pos_mul)
            output_cuda = learned_nonlin(x_cuda, pos_add_cuda, pos_mul_cuda)

            diff = (output - output_cuda.to(cpu)).abs().sum()
            sum_abs = output.abs().sum()
            print("Diff = ", diff, ", abs = ", sum_abs)

            if (diff / sum_abs).item() > 0.001:
                print("output = ", output)
                print("output_cuda = ", output_cuda)
                # Explicit raise: unlike `assert`, it is not stripped under -O.
                raise AssertionError("outputs differ")



def test_learned_nonlin_rand_grad():
    """Spot-check learned_nonlin's gradients against finite differences.

    Runs 30 trials over random problem sizes (the largest of N, C, H, W is
    halved until N * C * H * W <= 65535).  Each trial runs on CPU and, if
    available, cuda:0, in float32 and float64.  It back-propagates a random
    output gradient ``g``.

    Then, for each of pos_add, pos_mul and the input, it applies a random
    perturbation ``d`` of scale 1e-5.  It prints the predicted change
    ``sum(d * grad)`` next to the measured change ``sum(g * (f(perturbed) -
    f))``.  Nothing is asserted.

    NOTE(review): this calls ``learned_nonlin(input, pos_add, pos_mul)``,
    which differs from the ``learned_nonlin(x, params, dim=...)`` call in
    test_learned_nonlin_basic, and ``__main__`` does not run it. Confirm the
    current signature before re-enabling.
    """
    for _ in range(30):
        N = random.randint(1, 256)
        C = random.randint(1, 64)
        H = random.randint(1, 128)
        W = random.randint(1, 128)

        # Halve the largest dimension (ties: N, then C, then H) until small enough.
        while N * C * H * W > 65535:
            if N >= C and N >= H and N >= W:
                N = N // 2
            elif C >= H and C >= W:
                C = C // 2
            elif H >= W:
                H = H // 2
            else:
                W = W // 2

        for device in [torch.device('cpu'), torch.device('cuda:0')]:
            if device == torch.device('cuda:0') and not torch.cuda.is_available():
                print("Warning: CUDA not available, not testing this part.")
                continue
            for dtype in [torch.float32, torch.float64]:
                print("dtype=", dtype, ", device=", device)
                x = torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)

                # Odd kernel sizes in {1, 3, 5, 7, 9, 11}.
                kH = random.randint(1, 10)
                kW = random.randint(1, 10)
                if kH % 2 == 0:
                    kH += 1
                if kW % 2 == 0:
                    kW += 1
                pos_add = torch.randn(C, kH, kW, dtype=dtype, device=device)
                pos_mul = torch.randn(C, kH, kW, dtype=dtype, device=device)
                x.requires_grad = True
                pos_add.requires_grad = True
                pos_mul.requires_grad = True

                output = learned_nonlin(x, pos_add, pos_mul)
                output_grad = torch.randn(N, C, H, W, dtype=dtype, device=device)
                output.backward(gradient=output_grad)

                delta = 1.0e-05
                # The perturbed evaluations only need values; skip building
                # autograd graphs for them.
                with torch.no_grad():
                    pos_delta = delta * torch.randn(C, kH, kW, dtype=dtype,
                                                    device=device)

                    pred_change = (pos_delta * pos_add.grad).sum().item()
                    perturbed = learned_nonlin(x, pos_add + pos_delta, pos_mul)
                    change = (output_grad * (perturbed - output)).sum().item()
                    print(f"For pos_add: pred_change={pred_change}, change={change}")

                    pred_change = (pos_delta * pos_mul.grad).sum().item()
                    perturbed = learned_nonlin(x, pos_add, pos_mul + pos_delta)
                    change = (output_grad * (perturbed - output)).sum().item()
                    print(f"For pos_mul: pred_change={pred_change}, change={change}")

                    input_delta = delta * torch.randn(N, 2 * C, H, W, dtype=dtype,
                                                      device=device)
                    pred_change = (input_delta * x.grad).sum().item()
                    perturbed = learned_nonlin(x + input_delta, pos_add, pos_mul)
                    change = (output_grad * (perturbed - output)).sum().item()
                    print(f"For input: pred_change={pred_change}, change={change}")
                # TODO(review): assert |pred_change - change| within a tolerance
                # once a reliable threshold is known (delta is only 1e-5, so the
                # float32 differences may be noisy).


if __name__ == "__main__":
    test_learned_nonlin_basic()
    # The tests below call learned_nonlin(input, pos_add, pos_mul), which differs
    # from the learned_nonlin(x, params, dim=...) call in test_learned_nonlin_basic;
    # presumably that is why they are switched off. Flip this flag to run them.
    run_legacy_tests = False
    if run_legacy_tests:
        test_learned_nonlin_rand_grad()
        test_learned_nonlin_zeros()
        test_learned_nonlin_compare()
        test_learned_nonlin_rand_compare()