import os
import sys
import unittest

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.codegen import model_to_pytorch_script

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin

# The tensor ops tested below follow the PyTorch v1.7.1 API.

class TestConvert(unittest.TestCase, ConvertMixin):
    @staticmethod
    def _match_state_dict(current_values, expected_format):
        # The converted model's parameter names differ from the original's,
        # so greedily match original parameter tensors to the converted
        # state_dict keys by shape (assumes shapes disambiguate parameters).
        result = {}
        for k, v in expected_format.items():
            for idx, cv in enumerate(current_values):
                if cv.shape == v.shape:
                    result[k] = cv
                    current_values.pop(idx)
                    break
        return result

    def checkExportImport(self, model, input, check_value=True):
        # Convert the model to Retiarii IR, regenerate PyTorch code from the
        # IR, rebuild the model from that code, transfer the weights, and
        # verify that both models produce matching outputs.
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        exec_vars = {}
        # The generated script defines the module class `_model`; execute it
        # and instantiate the converted model.
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
                                                      dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)
        if check_value:
            self.assertEqual(len(converted_output), len(expected_output))
            for a, b in zip(converted_output, expected_output):
                if hasattr(a, 'dtype') and a.dtype == torch.bool:
                    # XOR is all-False exactly when two bool tensors are equal
                    self.assertFalse((a ^ b).any().item())
                elif isinstance(a - b, int):
                    self.assertEqual(a - b, 0)
                else:
                    self.assertLess((a - b).abs().max().item(), 1E-4)
        return converted_model
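
    # Usage sketch for the helper above (illustrative only, not an extra
    # test): each test below builds a tiny nn.Module plus a sample input
    # tuple and round-trips it through codegen, e.g.
    #
    #   converted = self.checkExportImport(SimpleOp(), (torch.randn(3),))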

    # skip torch.Tensor.new_tensor as it is not supported by jit

    def test_basic_new_full(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                # requires_grad is not supported by jit
                # aten::new_full(Tensor self, int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
                # Keyword argument requires_grad unknown.
                out = x.new_full((3, 4), 3.141592, dtype=torch.float32, device=torch.device('cpu'))
                return out
        self.checkExportImport(SimpleOp(), (torch.ones((2,), dtype=torch.float64), ))

    def test_basic_new_empty(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.new_empty((2, 3), dtype=torch.int8, device=torch.device('cpu'))
                return out
        self.checkExportImport(SimpleOp(), (torch.ones(()), ), check_value=False)

    # skip torch.Tensor.new_ones as it is not supported by jit

    # requires_grad=False is not supported by jit
    def test_basic_new_zeros(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.new_zeros((2, 3))
                return out
        self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))

    def test_basic_is_cuda(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                return torch.tensor([x.is_cuda], dtype=torch.bool, device=torch.device('cpu'))
        self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))

    # is_quantized
    # is_meta
    # device
    # grad
    # ndim
    # T
    # real
    # imag

    def test_basic_abs(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.abs()
                out11 = x.absolute()
                out2 = torch.abs(x)
                #out3 = x.abs_()
                #out33 = x.absolute_()
                return out1, out11, out2#, out3, out33
        self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3]), ))

    # TODO: topological sort should be improved
    #def forward(self, x__1):
    #    __Acos2 = x__1.acos()
    #    __Acos_3 = x__1.acos_()
    #    __Acos1 = x__1.acos()
    #    __TupleConstruct4 = (__Acos1,__Acos2,__Acos_3)
    #    return __TupleConstruct4
    def test_basic_acos_asin_atan(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.acos()
                out2 = torch.acos(x)
                # TODO: add back this line
                #out = x.acos_()
                out3 = x.asin()
                out4 = torch.asin(x)
                out5 = x.atan()
                out6 = torch.atan(x)
                out7 = x.atan2(y)
                out8 = torch.atan2(x, y)
                return out1, out2, out3, out4, out5, out6, out7, out8#, out
        self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), torch.tensor([1.0, 0.6, -0.3]), ))

    # arccos is not supported by jit

    def test_basic_add(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                t = torch.tensor([-1.0, -0.5, 0.2])
                out1 = x.add(t)
                out2 = x.add(t, alpha=2)
                #out3 = x.add_(t)
                return out1, out2#, out3
        self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), ))

    def test_basic_addbmm(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z, m):
                out1 = x.addbmm(y, z, beta=2, alpha=3)
                out2 = torch.addbmm(x, y, z, beta=2, alpha=3)
                #out3 = x.addbmm_(y, z, beta=2, alpha=3)
                out3 = m.baddbmm(y, z, beta=2, alpha=3)
                out4 = torch.baddbmm(m, y, z, beta=2, alpha=3)
                out5 = torch.bmm(y, z) # deterministic is not supported by jit
                return out1, out2, out3, out4, out5
        self.checkExportImport(SimpleOp(), (torch.randn(3, 5), torch.randn(10, 3, 4), torch.randn(10, 4, 5), torch.randn(10, 3, 5), ))

    def test_basic_addcdiv(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addcdiv(y, z, value=2)
                out2 = torch.addcdiv(x, y, z, value=2)
                # addcdiv_
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))

    def test_basic_addcmul(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addcmul(y, z, value=0.1)
                out2 = torch.addcmul(x, y, z, value=0.1)
                # addcmul_
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))

    def test_basic_addmm(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addmm(y, z, beta=0.1, alpha=0.2)
                out2 = torch.addmm(x, y, z, beta=0.1, alpha=0.2)
                # addmm_
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3), ))

    def test_basic_addmv(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addmv(y, z, beta=0.1, alpha=0.2)
                out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))

    def test_basic_addr(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addr(y, z, beta=2, alpha=3)
                out2 = torch.addr(x, y, z, beta=2, alpha=3)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.zeros(3, 2), torch.arange(1., 4.), torch.arange(1., 3.), ))

    def test_basic_allclose(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.allclose(y, rtol=1e-05, atol=1e-08, equal_nan=False)
                out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))

    def test_basic_angle(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.angle()
                out2 = torch.angle(x)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]), ))

    # skip apply_(callable) for now

    def test_basic_argmax_argmin(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.argmax()
                out2 = torch.argmax(x)
                out3 = x.argmax(dim=1)
                out4 = torch.argmax(x, dim=1)
                out5 = x.argmax(dim=1, keepdim=True)
                o1 = x.argmin()
                o2 = torch.argmin(x)
                o3 = x.argmin(dim=1)
                o4 = x.argmin(dim=1, keepdim=True)
                return out1, out2, out3, out4, out5, o1, o2, o3, o4
        self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))

    def test_basic_argsort(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.argsort()
                out2 = x.argsort(dim=1)
                out3 = x.argsort(dim=1, descending=True)
                out4 = torch.argsort(x, dim=1, descending=True)
                return out1, out2, out3, out4
        self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))

    # skip backward(gradient=None, retain_graph=None, create_graph=False)

    def test_basic_bernoulli(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                # generator=torch.Generator() is not supported by jit
                out = x.bernoulli()
                return out
        self.checkExportImport(SimpleOp(), (torch.ones(3, 3), ))

    # bfloat16/bool/byte/char is not supported by jit

    def test_basic_bincount(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.bincount()
                out2 = torch.bincount(x)
                out3 = x.bincount(weights=y)
                out4 = x.bincount(weights=y, minlength=2)
                return out1, out2, out3, out4
        self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))

    def test_basic_bitwise(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.bitwise_not()
                out2 = x.bitwise_and(y)
                out3 = x.bitwise_or(y)
                out4 = x.bitwise_xor(y)
                return out1, out2, out3, out4
        self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8), ))

    # cauchy_ is not supported yet

    def test_basic_ceil(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.ceil()
                return out1
        self.checkExportImport(SimpleOp(), (torch.randn(4), ))


class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
    pass
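
# A convenience entry point (a minimal sketch; within the NNI repo these
# tests are normally collected by the project's test runner instead):
if __name__ == '__main__':
    unittest.main()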