import os
import sys
import unittest
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script

from .convert_mixin import ConvertMixin, ConvertWithShapeMixin


class TestModels(unittest.TestCase, ConvertMixin):
    @staticmethod
    def _match_state_dict(current_values, expected_format):
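        # Pair each entry of the generated model's state_dict with a weight from the
        # original model by tensor shape (greedy, first match wins); the converted
        # module names differ from the original ones, so matching by name is not possible.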
        result = {}
        for k, v in expected_format.items():
            for idx, cv in enumerate(current_values):
                if cv.shape == v.shape:
                    result[k] = cv
                    current_values.pop(idx)
                    break
        return result

    def run_test(self, model, input, check_value=True):
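        # Convert the model and its dummy input into the graph IR via the
        # converter supplied by ConvertMixin / ConvertWithShapeMixin.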
        model_ir = self._convert_model(model, input)
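        # Generate standalone PyTorch code from the IR and print it for debugging.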
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

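        # Execute the generated script; the emitted code is expected to define a
        # module class named `_model`, which is instantiated here.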
        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
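        # Transfer the original weights into the converted model; parameters are
        # matched by shape (see _match_state_dict) rather than by name.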
        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
                                                      dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
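        # Run both models in eval mode on the same input and compare their outputs.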
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)
        if check_value:
            try:
                self.assertEqual(len(converted_output), len(expected_output))
                for a, b in zip(converted_output, expected_output):
                    # torch.eq alone returns a tensor and asserts nothing,
                    # so check explicitly that every element matches
                    self.assertTrue(torch.eq(a, b).all())
            except TypeError:
                # the output is not a sized sequence of tensors; compare it directly
                self.assertEqual(converted_output, expected_output)
        return converted_model

    def test_nested_modulelist(self):
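        # A chain of num_nodes x num_ops_per_node linear layers stored in nested
        # ModuleLists, to exercise conversion of nested containers.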
        class Net(nn.Module):
            def __init__(self, num_nodes, num_ops_per_node):
                super().__init__()
                self.ops = nn.ModuleList()
                self.num_nodes = num_nodes
                self.num_ops_per_node = num_ops_per_node
                for _ in range(num_nodes):
                    self.ops.append(nn.ModuleList([nn.Linear(16, 16) for __ in range(num_ops_per_node)]))

            def forward(self, x):
                state = x
                for ops in self.ops:
                    for op in ops:
                        state = op(state)
                return state

        model = Net(4, 2)
        x = torch.rand((16, 16), dtype=torch.float)
        self.run_test(model, (x, ))

    def test_append_input_tensor(self):
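        # The forward signature takes a list of tensors and appends to it in place;
        # the explicit List[torch.Tensor] annotation is presumably what lets the
        # TorchScript-based conversion type the non-tensor input.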
        from typing import List
        class Net(nn.Module):
            def __init__(self, num_nodes):
                super().__init__()
                self.ops = nn.ModuleList()
                self.num_nodes = num_nodes
                for _ in range(num_nodes):
                    self.ops.append(nn.Linear(16, 16))
            def forward(self, x: List[torch.Tensor]):
                state = x
                for ops in self.ops:
                    state.append(ops(state[-1]))
                return state[-1]

        model = Net(4)
        x = torch.rand((1, 16), dtype=torch.float)
        self.run_test(model, ([x], ))

class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
    pass