"vscode:/vscode.git/clone" did not exist on "15674793194ca0b5f7729d158e5f6b5eec4a74e5"
test_convert_shape.py 3.34 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import unittest
import torch

import nni.retiarii.nn.pytorch as nn

from .convert_mixin import ConvertWithShapeMixin


class TestShape(unittest.TestCase, ConvertWithShapeMixin):
    """Shape-inference tests for model conversion via ``ConvertWithShapeMixin``.

    Each test converts a small network with ``self._convert_model`` and checks
    the ``input_shape`` / ``output_shape`` parameters recorded on the resulting
    IR nodes. Shapes are compared as lists of shapes (presumably one entry per
    tensor input/output -- every case here has exactly one).
    """

    def _first_node(self, model_ir, type_name):
        """Return the first IR node of ``type_name`` in ``model_ir``.

        Fails the test with a descriptive message, instead of erroring with a
        bare ``IndexError``, when the converted model has no such node.
        """
        nodes = model_ir.get_nodes_by_type(type_name)
        self.assertGreater(
            len(nodes), 0, f'no node of type {type_name!r} in converted model')
        return nodes[0]

    def _assert_io_shapes(self, node, input_shape, output_shape):
        """Assert the ``input_shape`` and ``output_shape`` recorded on ``node``."""
        params = node.operation.parameters
        self.assertEqual(params.get('input_shape'), input_shape)
        self.assertEqual(params.get('output_shape'), output_shape)

    def test_simple_convnet(self):
        """Shapes are recorded on each node of a flat conv -> relu -> pool net."""
        class ConvNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 1, 3)
                self.relu = nn.ReLU()
                self.pool = nn.MaxPool2d(kernel_size=2)
            def forward(self, x):
                return self.pool(self.relu(self.conv(x)))

        net = ConvNet()
        dummy_input = torch.randn((1, 3, 224, 224))
        model_ir = self._convert_model(net, dummy_input)

        conv_node = self._first_node(model_ir, '__torch__.torch.nn.modules.conv.Conv2d')
        relu_node = self._first_node(model_ir, '__torch__.torch.nn.modules.activation.ReLU')
        pool_node = self._first_node(model_ir, '__torch__.torch.nn.modules.pooling.MaxPool2d')
        # 3x3 conv without padding: 224 - 3 + 1 = 222.
        self._assert_io_shapes(conv_node, [[1, 3, 224, 224]], [[1, 1, 222, 222]])
        # ReLU is elementwise, so the shape passes through unchanged.
        self._assert_io_shapes(relu_node, [[1, 1, 222, 222]], [[1, 1, 222, 222]])
        # 2x2 max-pool (stride defaults to kernel_size): 222 // 2 = 111.
        self._assert_io_shapes(pool_node, [[1, 1, 222, 222]], [[1, 1, 111, 111]])

    def test_nested_module(self):
        """Shape info is recorded for a nested submodule (checked via a ``_cell`` node)."""
        class ConvRelu(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 1, 3)
                self.relu = nn.ReLU()
            def forward(self, x):
                return self.relu(self.conv(x))

        class ConvNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = ConvRelu()
                self.pool = nn.MaxPool2d(kernel_size=2)
            def forward(self, x):
                return self.pool(self.conv(x))

        net = ConvNet()
        dummy_input = torch.randn((1, 3, 224, 224))
        model_ir = self._convert_model(net, dummy_input)

        # Check that shape propagation reaches the nested module: the expected
        # shapes are ConvRelu's (3x3 conv + elementwise ReLU: 224 -> 222).
        cell_node = self._first_node(model_ir, '_cell')
        self._assert_io_shapes(cell_node, [[1, 3, 224, 224]], [[1, 1, 222, 222]])

    def test_layerchoice(self):
        """Shape info is recorded on both LayerChoice candidate convolutions."""
        class ConvNet(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.LayerChoice([
                    nn.Conv2d(3, 1, 3),
                    nn.Conv2d(3, 1, 5, padding=1),
                ])
                self.pool = nn.MaxPool2d(kernel_size=2)
            def forward(self, x):
                return self.pool(self.conv(x))

        net = ConvNet()
        dummy_input = torch.randn((1, 3, 224, 224))
        model_ir = self._convert_model(net, dummy_input)

        # Check the shape info of each candidate. Both are built to output
        # 222x222: 3x3 without padding gives 224 - 3 + 1, and 5x5 with
        # padding=1 gives 224 + 2 - 5 + 1.
        conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
        self.assertGreaterEqual(
            len(conv_nodes), 2, 'expected a Conv2d node for each LayerChoice candidate')
        for idx, conv_node in enumerate(conv_nodes[:2]):
            with self.subTest(candidate=idx):
                self.assertEqual(
                    conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])