network.py 4.92 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
4
5
6
7
8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import pickle
import re

import torch
9
10
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.nn.pytorch import LayerChoice
Yuge Zhang's avatar
Yuge Zhang committed
11
12
13
14
15
16
17
18
19
20
21
22

from blocks import ShuffleNetBlock, ShuffleXceptionBlock


class ShuffleNetV2OneShot(nn.Module):
    """One-shot ShuffleNetV2 supernet whose searchable blocks are ``LayerChoice`` modules.

    Layout: a stride-2 3x3 stem convolution, four stages of searchable blocks
    (``stage_blocks = [4, 4, 8, 4]`` blocks with ``stage_channels = [64, 160, 320, 640]``
    output channels; the first block of each stage uses stride 2), a 1x1 convolution,
    global average pooling, dropout (p=0.1) and a bias-free linear classifier.

    Every searchable block is a ``LayerChoice`` over four candidates, in this order:
    ShuffleNetBlock with kernel size 3, 5, 7, then ShuffleXceptionBlock -- the same
    order as ``block_keys``. Choices are labelled ``"LayerChoice1"``,
    ``"LayerChoice2"``, ... in construction order.

    Args:
        input_size: Spatial size of the input image (assumed square). Must be divisible
            by 32, because the stem and each of the four stages halve the feature map.
        first_conv_channels: Output channels of the stem convolution.
        last_conv_channels: Output channels of the final 1x1 convolution.
        n_classes: Number of classifier outputs.
        affine: Forwarded to every BatchNorm built here and to every candidate block.
    """

    # Names of the LayerChoice candidates, in the order they are registered below.
    block_keys = [
        'shufflenet_3x3',
        'shufflenet_5x5',
        'shufflenet_7x7',
        'xception_3x3',
    ]

    def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024,
                 n_classes=1000, affine=False):
        super().__init__()

        # Stem + 4 stride-2 stages => total downsampling factor of 2**5 == 32.
        assert input_size % 32 == 0, \
            "input_size must be divisible by 32, got {}".format(input_size)
        self.stage_blocks = [4, 4, 8, 4]
        self.stage_channels = [64, 160, 320, 640]
        self._input_size = input_size
        # Tracks the spatial size as layers are built; it sizes the global pool below.
        self._feature_map_size = input_size
        self._first_conv_channels = first_conv_channels
        self._last_conv_channels = last_conv_channels
        self._n_classes = n_classes
        self._affine = affine
        # Running counter used to give every LayerChoice a unique label.
        self._layerchoice_count = 0

        # building first layer
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(first_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self._feature_map_size //= 2

        p_channels = first_conv_channels
        features = []
        for num_blocks, channels in zip(self.stage_blocks, self.stage_channels):
            features.extend(self._make_blocks(num_blocks, p_channels, channels))
            p_channels = channels
        self.features = nn.Sequential(*features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(last_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        # Kernel equals the final feature-map size, so a correctly sized input is
        # pooled down to 1x1.
        self.globalpool = nn.AvgPool2d(self._feature_map_size)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(last_conv_channels, n_classes, bias=False),
        )

        self._initialize_weights()

    def _make_blocks(self, blocks, in_channels, channels):
        """Build one stage: ``blocks`` LayerChoice modules producing ``channels`` channels.

        The first block of the stage takes ``in_channels`` inputs and uses stride 2.
        Later blocks map ``channels -> channels`` with stride 1. Each call advances
        ``self._layerchoice_count`` once per block and halves
        ``self._feature_map_size`` once per stage.

        Returns:
            list of ``LayerChoice`` modules, in order.
        """
        result = []
        for i in range(blocks):
            stride = 2 if i == 0 else 1
            inp = in_channels if i == 0 else channels
            oup = channels

            base_mid_channels = channels // 2
            mid_channels = int(base_mid_channels)  # prepare for scale
            self._layerchoice_count += 1
            # Candidate order must match ``block_keys``.
            choice_block = LayerChoice([
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
            ], label="LayerChoice" + str(self._layerchoice_count))
            result.append(choice_block)

            if stride == 2:
                self._feature_map_size //= 2
        return result

    def forward(self, x):
        """Run the supernet and return logits of shape ``(batch, n_classes)``.

        ``x`` must have 3 channels (the stem is ``Conv2d(3, ...)``). NOTE(review): it
        presumably should be NCHW with spatial size ``input_size``; otherwise the fixed
        global pool does not reduce to 1x1 and the classifier input size mismatches.
        """
        bs = x.size(0)
        x = self.first_conv(x)
        x = self.features(x)
        x = self.conv_last(x)
        x = self.globalpool(x)

        x = self.dropout(x)
        # Flatten the pooled (bs, C, 1, 1) map to (bs, C) for the linear classifier.
        x = x.contiguous().view(bs, -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialize parameters (and BatchNorm running means) of all submodules in place.

        - Conv2d: the module whose name contains ``'first'`` (the stem) gets N(0, 0.01).
          Every other conv gets N(0, 1 / weight.shape[1]). Biases are set to 0.
        - BatchNorm1d/2d: weight 1, bias 1e-4, running mean 0, each only when present
          (``affine=False`` removes weight/bias, ``track_running_stats=False`` removes
          the running mean).
        - Linear: weight N(0, 0.01), bias 0.
        """
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                # NOTE(review): substring match; any submodule path containing "first"
                # (not only ``first_conv``) takes this branch.
                if 'first' in name:
                    torch.nn.init.normal_(m.weight, 0, 0.01)
                else:
                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None for affine=False; running_mean is None when
                # track_running_stats=False. Guard each so init never touches None.
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0001)
                if m.running_mean is not None:
                    torch.nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)


def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
    """Load a checkpoint onto the CPU and return its state dict with ``module.`` stripped.

    If the loaded object has a ``"state_dict"`` entry, that entry is used. Otherwise
    the loaded object itself is treated as the state dict. A leading ``"module."``
    (as added by ``DataParallel``-style wrappers) is removed from every key that has it.
    Other keys are kept as they are.

    Args:
        filepath: Path of the checkpoint file passed to ``torch.load``.

    Returns:
        A new ``dict`` mapping the (possibly stripped) keys to the original values.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    loaded = torch.load(filepath, map_location=torch.device("cpu"))
    state = loaded["state_dict"] if "state_dict" in loaded else loaded

    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }