network.py 4.99 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
4
5
6
7
8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import pickle
import re

import torch
9
10
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.nn.pytorch import LayerChoice
11
from nni.retiarii.serializer import model_wrapper
Yuge Zhang's avatar
Yuge Zhang committed
12
13
14
15

from blocks import ShuffleNetBlock, ShuffleXceptionBlock


16
@model_wrapper
class ShuffleNetV2OneShot(nn.Module):
    """ShuffleNetV2-style network whose every block is a searchable ``LayerChoice``.

    The network is a stem convolution, four stages of choice blocks, a 1x1
    convolution, average pooling, dropout and a linear classifier. Each choice
    block offers four candidates: ``ShuffleNetBlock`` with kernel size 3, 5 or
    7, and ``ShuffleXceptionBlock``.

    Parameters
    ----------
    input_size : int
        Side length used to size the final average pooling. Must be divisible
        by 32, because the stem and the first block of each of the four stages
        all use stride 2.
    first_conv_channels : int
        Number of output channels of the stem convolution.
    last_conv_channels : int
        Number of output channels of the final 1x1 convolution; also the input
        width of the classifier.
    n_classes : int
        Number of outputs of the classifier.
    affine : bool
        Forwarded as ``affine`` to the ``BatchNorm2d`` layers created here and
        to every candidate block.
    """

    # Names of the four candidates, listed in the same order as the candidates
    # passed to each LayerChoice in ``_make_blocks``.
    block_keys = [
        'shufflenet_3x3',
        'shufflenet_5x5',
        'shufflenet_7x7',
        'xception_3x3',
    ]

    def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024,
                 n_classes=1000, affine=False):
        super().__init__()

        assert input_size % 32 == 0, "input_size must be divisible by 32"
        # Number of choice blocks per stage and output channels of each stage.
        self.stage_blocks = [4, 4, 8, 4]
        self.stage_channels = [64, 160, 320, 640]
        self._input_size = input_size
        # Tracks the spatial size as stride-2 layers are added, so that the
        # pooling kernel below can cover the whole final feature map.
        self._feature_map_size = input_size
        self._first_conv_channels = first_conv_channels
        self._last_conv_channels = last_conv_channels
        self._n_classes = n_classes
        self._affine = affine
        # Running counter used to give every LayerChoice a distinct label.
        self._layerchoice_count = 0

        # building first layer
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(first_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self._feature_map_size //= 2  # the stem convolution has stride 2

        p_channels = first_conv_channels
        features = []
        for num_blocks, channels in zip(self.stage_blocks, self.stage_channels):
            features.extend(self._make_blocks(num_blocks, p_channels, channels))
            p_channels = channels
        self.features = nn.Sequential(*features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(last_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self.globalpool = nn.AvgPool2d(self._feature_map_size)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(last_conv_channels, n_classes, bias=False),
        )

        self._initialize_weights()

    def _make_blocks(self, blocks, in_channels, channels):
        """Build the choice blocks of one stage.

        Parameters
        ----------
        blocks : int
            Number of choice blocks in the stage.
        in_channels : int
            Input channels of the first block of the stage.
        channels : int
            Output channels of every block (and input channels of every block
            after the first).

        Returns
        -------
        list
            The ``LayerChoice`` modules of the stage, in order. As a side
            effect ``self._layerchoice_count`` is advanced once per block and
            ``self._feature_map_size`` is halved for the stride-2 first block.
        """
        result = []
        for i in range(blocks):
            # Only the first block of a stage downsamples and changes width.
            stride = 2 if i == 0 else 1
            inp = in_channels if i == 0 else channels
            oup = channels

            base_mid_channels = channels // 2
            mid_channels = int(base_mid_channels)  # prepare for scale

            self._layerchoice_count += 1
            choice_block = LayerChoice([
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
            ], label="LayerChoice" + str(self._layerchoice_count))
            result.append(choice_block)

            if stride == 2:
                self._feature_map_size //= 2
        return result

    def forward(self, x):
        """Run the network on a batch.

        ``x`` is fed to a ``Conv2d`` with 3 input channels; the result is
        flattened to ``(x.size(0), -1)`` before the classifier.
        NOTE(review): the flatten only matches the classifier width when the
        pooled map is 1x1, i.e. presumably when the input is
        ``input_size`` x ``input_size`` -- confirm against callers.
        """
        bs = x.size(0)
        x = self.first_conv(x)
        x = self.features(x)
        x = self.conv_last(x)
        x = self.globalpool(x)

        x = self.dropout(x)
        x = x.contiguous().view(bs, -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise the parameters of every conv, batch-norm and linear submodule.

        * ``Conv2d``: weights drawn from a normal distribution with std 0.01
          when the module name contains ``'first'``, otherwise with std
          ``1 / weight.shape[1]``; bias (if any) set to 0.
        * ``BatchNorm2d`` / ``BatchNorm1d``: weight set to 1 and bias to 1e-4
          when present, running mean set to 0.
        * ``Linear``: weights drawn from a normal distribution with std 0.01;
          bias (if any) set to 0.
        """
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    torch.nn.init.normal_(m.weight, 0, 0.01)
                else:
                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight and bias may be None (this model passes affine=False
                # by default), so both must be guarded for either batch-norm
                # flavour before they are initialised.
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0001)
                torch.nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
Yuge Zhang's avatar
Yuge Zhang committed
126
127
128
129


def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
    """Load a checkpoint onto the CPU and return it with ``module.`` prefixes removed.

    Parameters
    ----------
    filepath : str
        Path of the checkpoint file passed to ``torch.load``.

    Returns
    -------
    dict
        The loaded mapping (or its ``"state_dict"`` entry when that key is
        present), where one leading ``"module."`` has been stripped from every
        key that starts with it. Values are returned unchanged. If two keys
        become equal after stripping, the later one wins.
    """
    # NOTE(security): torch.load unpickles the file; only use it on
    # checkpoints from a trusted source.
    checkpoint = torch.load(filepath, map_location=torch.device("cpu"))
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # Presumably the prefix comes from saving a DataParallel-wrapped model;
    # only a single leading occurrence is removed.
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }