# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload

import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper

from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight


@overload
def make_divisible(v: Union[int, float], divisor, min_val=None) -> int:
    ...


@overload
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float]], divisor, min_val=None) -> nn.ChoiceOf[int]:
    ...


def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float], int, float], divisor, min_val=None) -> nn.MaybeChoice[int]:
    """
    This function is taken from the original TensorFlow repository.
    It ensures that all layers have a channel number that is divisible by ``divisor`` (8 throughout this file).
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_val is None:
        min_val = divisor
    # This should work for both value choices and constants.
    new_v = nn.ValueChoice.max(min_val, round(v + divisor // 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    return nn.ValueChoice.condition(new_v < 0.9 * v, new_v + divisor, new_v)
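
# A worked sketch of the rounding above (plain ints, no ValueChoice; values are illustrative only):
#   make_divisible(37, 8)  -> round(37 + 4) // 8 * 8 = 40; 40 >= 0.9 * 37, so 40 is returned.
#   make_divisible(20, 16) -> rounds down to 16; 16 < 0.9 * 20, so the divisor is added back and 32 is returned.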


def simplify_sequential(sequentials: List[nn.Module]) -> Iterator[nn.Module]:
    """
    Flatten the sequential blocks so that the hierarchy looks better.
    Eliminate identity modules automatically.
    """
    for module in sequentials:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                # no recursive expansion
                if not isinstance(submodule, nn.Identity):
                    yield submodule
        else:
            if not isinstance(module, nn.Identity):
                yield module
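
# A small illustration (hypothetical `conv` and `bn` modules, e.g. plain nn.Conv2d / nn.BatchNorm2d):
#   simplify_sequential([nn.Sequential(conv, nn.Identity()), nn.Identity(), bn])
# yields `conv` and `bn` only: one level of nn.Sequential is flattened and nn.Identity modules are dropped.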


class ConvBNReLU(nn.Sequential):
    """
    The template for a conv-bn-relu block.
    """

    def __init__(
        self,
        in_channels: nn.MaybeChoice[int],
        out_channels: nn.MaybeChoice[int],
        kernel_size: nn.MaybeChoice[int] = 3,
        stride: int = 1,
        groups: nn.MaybeChoice[int] = 1,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
    ) -> None:
        padding = (kernel_size - 1) // 2 * dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6
        # If no normalization is used, set bias to True
        # https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L194
        norm = norm_layer(cast(int, out_channels))
        no_normalization = isinstance(norm, nn.Identity)
        blocks: List[nn.Module] = [
            nn.Conv2d(
                cast(int, in_channels),
                cast(int, out_channels),
                cast(int, kernel_size),
                stride,
                cast(int, padding),
                dilation=dilation,
                groups=cast(int, groups),
                bias=no_normalization
            ),
            # Normalization (either batch norm or identity)
            norm,
            # Some PyTorch implementations insert an SE block here to faithfully reproduce the paper.
            # We follow the more widely accepted approach of placing SE outside this block.
            # Reference: https://github.com/d-li14/mobilenetv3.pytorch/issues/18
            activation_layer(inplace=True)
        ]

        super().__init__(*simplify_sequential(blocks))
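
# A minimal usage sketch (widths are illustrative, not from any searched architecture):
#   ConvBNReLU(32, 64, kernel_size=3, stride=2)
# builds Conv2d(32, 64, 3, stride=2, padding=1, bias=False) -> BatchNorm2d(64) -> ReLU6(inplace=True).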


class DepthwiseSeparableConv(nn.Sequential):
    """
    In the original MobileNetV2 implementation, this is InvertedResidual with an expand ratio of 1.
    A residual connection is added if the input and output shapes are the same.

    References:

    - https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L90
    - https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L433
    - https://github.com/ultmaster/AceNAS/blob/46c8895f/searchspace/proxylessnas/utils.py#L100
    """

    def __init__(
        self,
        in_channels: nn.MaybeChoice[int],
        out_channels: nn.MaybeChoice[int],
        kernel_size: nn.MaybeChoice[int] = 3,
        stride: int = 1,
        squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        blocks = [
            # dw
            ConvBNReLU(in_channels, in_channels, stride=stride, kernel_size=kernel_size, groups=in_channels,
                       norm_layer=norm_layer, activation_layer=activation_layer),
            # optional se
            squeeze_excite(in_channels, in_channels) if squeeze_excite else nn.Identity(),
            # pw-linear
            ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
        ]
        super().__init__(*simplify_sequential(blocks))
        # NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
        self.has_skip = stride == 1 and in_channels is out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.has_skip:
            return x + super().forward(x)
        else:
            return super().forward(x)
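
# A minimal usage sketch (width is illustrative):
#   width = 32
#   block = DepthwiseSeparableConv(width, width, kernel_size=3, stride=1)
# Here has_skip is True (stride == 1 and the same `width` object is passed for input and output channels),
# so forward() adds the residual connection around the depthwise + pointwise convolutions.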


class InvertedResidual(nn.Sequential):
    """
    An Inverted Residual Block, sometimes called an MBConv Block, is a type of residual block used in image models;
    it uses an inverted structure for efficiency reasons.

    It was originally proposed for the `MobileNetV2 <https://arxiv.org/abs/1801.04381>`__ CNN architecture.
    It has since been reused for several mobile-optimized CNNs.
    It follows a narrow -> wide -> narrow approach, hence the inversion.
    It first widens with a 1x1 convolution, then uses a 3x3 depthwise convolution (which greatly reduces the number of parameters),
    then a 1x1 convolution is used to reduce the number of channels so input and output can be added.

    This implementation is a mixture of:

    - https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L453
    - https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L134
    """

    def __init__(
        self,
        in_channels: nn.MaybeChoice[int],
        out_channels: nn.MaybeChoice[int],
        expand_ratio: nn.MaybeChoice[float],
        kernel_size: nn.MaybeChoice[int] = 3,
        stride: int = 1,
        squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.stride = stride
        self.out_channels = out_channels
        assert stride in [1, 2]

        hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))

        # NOTE: an equality check (==) does NOT work for ValueChoice; "is" must be used instead.
        self.has_skip = stride == 1 and in_channels is out_channels

        layers: List[nn.Module] = [
            # point-wise convolution
            # NOTE: some papers omit this point-wise convolution when stride = 1.
            # In our implementation, if this pw convolution is intended to be omitted,
            # please use DepthwiseSeparableConv instead.
            ConvBNReLU(in_channels, hidden_ch, kernel_size=1,
                       norm_layer=norm_layer, activation_layer=activation_layer),
            # depth-wise
            ConvBNReLU(hidden_ch, hidden_ch, stride=stride, kernel_size=kernel_size, groups=hidden_ch,
                       norm_layer=norm_layer, activation_layer=activation_layer),
            # SE
            squeeze_excite(
                cast(int, hidden_ch),
                cast(int, in_channels)
            ) if squeeze_excite is not None else nn.Identity(),
            # pw-linear
            ConvBNReLU(hidden_ch, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity),
        ]

        super().__init__(*simplify_sequential(layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.has_skip:
            return x + super().forward(x)
        else:
            return super().forward(x)
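
# A worked sketch of the channel arithmetic (values illustrative, not from a searched model):
#   InvertedResidual(24, 40, expand_ratio=6, kernel_size=5, stride=2)
# expands 24 -> hidden_ch = make_divisible(24 * 6, 8) = 144 with a 1x1 conv, applies a 5x5 depthwise
# conv with stride 2 on 144 channels, then projects 144 -> 40 with a 1x1 linear conv; no residual
# connection is added because stride != 1.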


def inverted_residual_choice_builder(
    expand_ratios: List[int],
    kernel_sizes: List[int],
    downsample: bool,
    stage_input_width: int,
    stage_output_width: int,
    label: str
):
    def builder(index):
        stride = 1
        inp = stage_output_width

        if index == 0:
            # first layer in stage
            # do downsample and width reshape
            inp = stage_input_width
            if downsample:
                stride = 2

        oup = stage_output_width

        op_choices = {}
        for exp_ratio in expand_ratios:
            for kernel_size in kernel_sizes:
                op_choices[f'k{kernel_size}e{exp_ratio}'] = InvertedResidual(inp, oup, exp_ratio, kernel_size, stride)

        # It can be implemented with ValueChoice, but we use LayerChoice here
        # to be aligned with the intention of the original ProxylessNAS.
        return nn.LayerChoice(op_choices, label=f'{label}_i{index}')

    return builder
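
# For example (with the arguments used by ProxylessNAS below), a builder created with
# expand_ratios=[3, 6] and kernel_sizes=[3, 5, 7] returns, for each index, a LayerChoice over
# six InvertedResidual candidates keyed 'k3e3', 'k5e3', 'k7e3', 'k3e6', 'k5e6', 'k7e6',
# under the label '{label}_i{index}'.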


@model_wrapper
class ProxylessNAS(nn.Module):
    """
    The search space proposed by `ProxylessNAS <https://arxiv.org/abs/1812.00332>`__.

    Following the official implementation, the inverted residual with kernel size / expand ratio variations in each layer
    is implemented with a :class:`~nni.nas.nn.pytorch.LayerChoice` whose candidates cover all combinations. That means,
    when used with weight sharing, these candidates are treated as separate layers, and their weights are not shared in a fine-grained manner.
    We note that :class:`MobileNetV3Space` differs in this respect.

    This space could be implemented as part of :class:`MobileNetV3Space`, but we keep it separate, following convention.
    """

    def __init__(self, num_labels: int = 1000,
                 base_widths: Tuple[int, ...] = (32, 16, 32, 40, 80, 96, 192, 320, 1280),
                 dropout_rate: float = 0.,
                 width_mult: float = 1.0,
                 bn_eps: float = 1e-3,
                 bn_momentum: float = 0.1):

        super().__init__()

        assert len(base_widths) == 9
        # the width of the final feature-mix layer is included here as well
        widths = [make_divisible(width * width_mult, 8) for width in base_widths]
        downsamples = [True, False, True, True, True, False, True, False]

        self.num_labels = num_labels
        self.dropout_rate = dropout_rate
        self.bn_eps = bn_eps
        self.bn_momentum = bn_momentum

        self.stem = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)

        blocks: List[nn.Module] = [
            # first stage is fixed
            DepthwiseSeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
        ]

        # https://github.com/ultmaster/AceNAS/blob/46c8895fd8a05ffbc61a6b44f1e813f64b4f66b7/searchspace/proxylessnas/__init__.py#L21
        for stage in range(2, 8):
            # Rather than returning a fixed module here,
            # we return a builder that dynamically creates a module for each `repeat_idx`.
            builder = inverted_residual_choice_builder(
                [3, 6], [3, 5, 7], downsamples[stage], widths[stage - 1], widths[stage], f's{stage}')
            if stage < 7:
                blocks.append(nn.Repeat(builder, (1, 4), label=f's{stage}_depth'))
            else:
                # No mutation of depth in the last stage.
                # Directly call the builder to instantiate one block.
                blocks.append(builder(0))

        self.blocks = nn.Sequential(*blocks)

        # final layers
        self.feature_mix_layer = ConvBNReLU(widths[7], widths[8], kernel_size=1, norm_layer=nn.BatchNorm2d)
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout_layer = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(widths[-1], num_labels)

        reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.feature_mix_layer(x)
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.dropout_layer(x)
        x = self.classifier(x)
        return x

    def no_weight_decay(self):
        # This is useful for the timm optimizer:
        # do not apply weight decay (regularization) to the linear classifier layer.
        if hasattr(self, 'classifier'):
            return {'classifier.weight', 'classifier.bias'}
        return set()

    @classmethod
    def fixed_arch(cls, arch: dict) -> FixedFactory:
        return FixedFactory(cls, arch)

    @classmethod
    def load_searched_model(
        cls, name: str,
        pretrained: bool = False, download: bool = False, progress: bool = True
    ) -> nn.Module:

        init_kwargs = {}  # all default

        if name == 'acenas-m1':
            arch = {
                's2_depth': 2,
                's2_i0': 'k3e6',
                's2_i1': 'k3e3',
                's3_depth': 3,
                's3_i0': 'k5e3',
                's3_i1': 'k3e3',
                's3_i2': 'k5e3',
                's4_depth': 2,
                's4_i0': 'k3e6',
                's4_i1': 'k5e3',
                's5_depth': 4,
                's5_i0': 'k7e6',
                's5_i1': 'k3e6',
                's5_i2': 'k3e6',
                's5_i3': 'k7e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e6',
                's6_i2': 'k7e3',
                's6_i3': 'k7e3',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

        elif name == 'acenas-m2':
            arch = {
                's2_depth': 1,
                's2_i0': 'k5e3',
                's3_depth': 3,
                's3_i0': 'k3e6',
                's3_i1': 'k3e3',
                's3_i2': 'k5e3',
                's4_depth': 2,
                's4_i0': 'k7e6',
                's4_i1': 'k5e6',
                's5_depth': 4,
                's5_i0': 'k5e6',
                's5_i1': 'k5e3',
                's5_i2': 'k5e6',
                's5_i3': 'k3e6',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k5e6',
                's6_i2': 'k5e3',
                's6_i3': 'k5e6',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

        elif name == 'acenas-m3':
            arch = {
                's2_depth': 2,
                's2_i0': 'k3e3',
                's2_i1': 'k3e6',
                's3_depth': 2,
                's3_i0': 'k5e3',
                's3_i1': 'k3e3',
                's4_depth': 3,
                's4_i0': 'k5e6',
                's4_i1': 'k7e6',
                's4_i2': 'k3e6',
                's5_depth': 4,
                's5_i0': 'k7e6',
                's5_i1': 'k7e3',
                's5_i2': 'k7e3',
                's5_i3': 'k5e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e3',
                's6_i2': 'k7e6',
                's6_i3': 'k3e3',
                's7_depth': 1,
                's7_i0': 'k5e6'
            }

        elif name == 'proxyless-cpu':
            arch = {
                's2_depth': 4,
                's2_i0': 'k3e6',
                's2_i1': 'k3e3',
                's2_i2': 'k3e3',
                's2_i3': 'k3e3',
                's3_depth': 4,
                's3_i0': 'k3e6',
                's3_i1': 'k3e3',
                's3_i2': 'k3e3',
                's3_i3': 'k5e3',
                's4_depth': 2,
                's4_i0': 'k3e6',
                's4_i1': 'k3e3',
                's5_depth': 4,
                's5_i0': 'k5e6',
                's5_i1': 'k3e3',
                's5_i2': 'k3e3',
                's5_i3': 'k3e3',
                's6_depth': 4,
                's6_i0': 'k5e6',
                's6_i1': 'k5e3',
                's6_i2': 'k5e3',
                's6_i3': 'k3e3',
                's7_depth': 1,
                's7_i0': 'k5e6'
            }

            init_kwargs['base_widths'] = [40, 24, 32, 48, 88, 104, 216, 360, 1432]

        elif name == 'proxyless-gpu':
            arch = {
                's2_depth': 1,
                's2_i0': 'k5e3',
                's3_depth': 2,
                's3_i0': 'k7e3',
                's3_i1': 'k3e3',
                's4_depth': 2,
                's4_i0': 'k7e6',
                's4_i1': 'k5e3',
                's5_depth': 3,
                's5_i0': 'k5e6',
                's5_i1': 'k3e3',
                's5_i2': 'k5e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e6',
                's6_i2': 'k7e6',
                's6_i3': 'k5e6',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

            init_kwargs['base_widths'] = [40, 24, 32, 56, 112, 128, 256, 432, 1728]

        elif name == 'proxyless-mobile':
            arch = {
                's2_depth': 2,
                's2_i0': 'k5e3',
                's2_i1': 'k3e3',
                's3_depth': 4,
                's3_i0': 'k7e3',
                's3_i1': 'k3e3',
                's3_i2': 'k5e3',
                's3_i3': 'k5e3',
                's4_depth': 4,
                's4_i0': 'k7e6',
                's4_i1': 'k5e3',
                's4_i2': 'k5e3',
                's4_i3': 'k5e3',
                's5_depth': 4,
                's5_i0': 'k5e6',
                's5_i1': 'k5e3',
                's5_i2': 'k5e3',
                's5_i3': 'k5e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e6',
                's6_i2': 'k7e3',
                's6_i3': 'k7e3',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

        else:
            raise ValueError(f'Unsupported architecture with name: {name}')

        model_factory = cls.fixed_arch(arch)
        model = model_factory(**init_kwargs)

        if pretrained:
            weight_file = load_pretrained_weight(name, download=download, progress=progress)
            pretrained_weights = torch.load(weight_file)
            model.load_state_dict(pretrained_weights)

        return model
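
# A usage sketch (architecture names are defined above; pretrained weights require a download):
#   model = ProxylessNAS.load_searched_model('proxyless-mobile', pretrained=True, download=True)
#   logits = model(torch.randn(1, 3, 224, 224))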


def reset_parameters(model, model_init='he_fout', init_div_groups=False,
                     bn_momentum=0.1, bn_eps=1e-5):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            if model_init == 'he_fout':
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                if init_div_groups:
                    n /= m.groups
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif model_init == 'he_fin':
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                if init_div_groups:
                    n /= m.groups
                m.weight.data.normal_(0, math.sqrt(2. / n))
            else:
                raise NotImplementedError
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
            m.momentum = bn_momentum
            m.eps = bn_eps
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm1d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()