# blocks.py (3.33 KB)
# Extracted from a repository blame view; lines 1-12 committed by Yuge Zhang.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


class ShuffleNetBlock(nn.Module):
    """
    ShuffleNetV2-style block with a configurable main branch.

    The block receives ``inp`` input channels for both strides and emits ``oup`` channels.

    * ``stride == 1``: the input is channel-shuffled and split into two halves of
      ``inp // 2`` channels each. The first half is passed through untouched; the
      second half goes through the main branch.
    * ``stride == 2``: the full input goes through a projection branch (depth-wise conv,
      then point-wise conv) and, in parallel, through the main branch; both use stride 2.

    The two branch outputs are concatenated along dimension 1 (channels). The main branch
    produces ``oup - channels`` channels, where ``channels`` is ``inp // 2`` (stride 1) or
    ``inp`` (stride 2), so the concatenation has ``oup`` channels.

    Parameters
    ----------
    inp : int
        Number of input channels.
    oup : int
        Number of output channels. Must be greater than ``channels`` (see above).
    mid_channels : int
        Channel width of the main branch after its first non-final point-wise conv.
    ksize : int
        Kernel size of the depth-wise convs; one of 3, 5 or 7.
    stride : int
        Either 1 or 2.
    sequence : str
        Layout of the main branch, read left to right: ``"p"`` is a point-wise conv followed
        by BatchNorm and ReLU, ``"d"`` is a depth-wise conv followed by BatchNorm. Must be
        non-empty and end with ``"p"``. Only the first ``"d"`` applies ``stride``.
    affine : bool
        Forwarded as ``affine`` to every ``nn.BatchNorm2d`` in the block.
    """

    def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp", affine=True):
        super().__init__()
        assert stride in [1, 2]
        assert ksize in [3, 5, 7]
        # Channels entering each branch: half of the input when it is split (stride 1),
        # the whole input otherwise (stride 2).
        self.channels = inp // 2 if stride == 1 else inp
        self.inp = inp
        self.oup = oup
        self.mid_channels = mid_channels
        self.ksize = ksize
        self.stride = stride
        self.pad = ksize // 2  # keeps spatial size unchanged at stride 1 for odd kernels
        # The main branch supplies the output channels the bypass/projection branch does not.
        self.oup_main = oup - self.channels
        self._affine = affine
        assert self.oup_main > 0

        self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))

        if stride == 2:
            self.branch_proj = nn.Sequential(
                # dw
                nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
                          groups=self.channels, bias=False),
                nn.BatchNorm2d(self.channels, affine=affine),
                # pw (followed by BN and ReLU)
                nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.channels, affine=affine),
                nn.ReLU(inplace=True)
            )

    def forward(self, x):
        """
        Run the block on ``x`` (expected to have ``inp`` channels).

        Returns the concatenation, along dimension 1, of the bypass/projection branch output
        and the main branch output.
        """
        if self.stride == 2:
            # Downsampling: both branches consume the full input.
            x_proj, x = self.branch_proj(x), x
        else:
            # Split: one half bypasses, the other half feeds the main branch.
            x_proj, x = self._channel_shuffle(x)
        return torch.cat((x_proj, self.branch_main(x)), 1)

    def _decode_point_depth_conv(self, sequence):
        """
        Translate ``sequence`` into the list of layers forming the main branch.

        Raises
        ------
        ValueError
            If ``sequence`` is empty, or a non-final token is neither ``"d"`` nor ``"p"``.
        AssertionError
            If the final token is not ``"p"``.
        """
        if not sequence:
            # An empty main branch would act as an identity with ``channels`` outputs instead
            # of ``oup_main``, silently giving the block the wrong number of output channels.
            raise ValueError("Conv sequence must not be empty.")
        result = []
        first_depth = first_point = True
        last_index = len(sequence) - 1
        pc = c = self.channels  # pc: input channels of the current conv; c: its output channels
        for i, token in enumerate(sequence):
            # compute output channels of this conv
            if i == last_index:
                assert token == "p", "Last conv must be point-wise conv."
                c = self.oup_main
            elif token == "p" and first_point:
                c = self.mid_channels
            if token == "d":
                # depth-wise conv; only the first one applies the block stride
                assert pc == c, "Depth-wise conv must not change channels."
                result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
                                        groups=c, bias=False))
                result.append(nn.BatchNorm2d(c, affine=self._affine))
                first_depth = False
            elif token == "p":
                # point-wise conv
                result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
                result.append(nn.BatchNorm2d(c, affine=self._affine))
                result.append(nn.ReLU(inplace=True))
                first_point = False
            else:
                raise ValueError("Conv sequence must be d and p.")
            pc = c
        return result

    def _channel_shuffle(self, x):
        """
        Split the 4-D tensor ``x`` into two tensors of ``num_channels // 2`` channels.

        Channels with even index go to the first returned tensor and channels with odd
        index go to the second, preserving batch and spatial dimensions.
        """
        bs, num_channels, height, width = x.size()
        # The reshapes below only need an even channel count; the stricter multiple-of-4
        # requirement is kept unchanged from the original implementation.
        assert (num_channels % 4 == 0)
        x = x.reshape(bs * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleXceptionBlock(ShuffleNetBlock):
    """
    ShuffleNetBlock whose main branch is fixed to the ``"dpdpdp"`` layout
    (three depth-wise / point-wise pairs) with 3x3 depth-wise kernels.
    """

    def __init__(self, inp, oup, mid_channels, stride, affine=True):
        super().__init__(
            inp,
            oup,
            mid_channels,
            ksize=3,
            stride=stride,
            sequence="dpdpdp",
            affine=affine,
        )