encnet.py 7.5 KB
Newer Older
Zhang's avatar
v0.4.2  
Zhang committed
1
2
3
4
5
6
7
8
9
###########################################################################
# Created by: Hang Zhang 
# Email: zhang.hang@rutgers.edu 
# Copyright (c) 2017
###########################################################################

import torch
from torch.autograd import Variable
import torch.nn as nn
Hang Zhang's avatar
Hang Zhang committed
10
import torch.nn.functional as F
Zhang's avatar
v0.4.2  
Zhang committed
11
12
13
14
15

import encoding
from .base import BaseNet
from .fcn import FCNHead

Hang Zhang's avatar
Hang Zhang committed
16
# Public API exported by ``from encnet import *``.
__all__ = [
    'EncNet',
    'EncModule',
    'get_encnet',
    'get_encnet_resnet50_pcontext',
    'get_encnet_resnet101_pcontext',
    'get_encnet_resnet50_ade',
]

class EncNet(BaseNet):
    """EncNet semantic-segmentation network on top of a ``BaseNet`` backbone.

    The main prediction comes from an ``EncHead`` applied to the backbone
    feature maps. When ``aux`` is enabled, an ``FCNHead`` applied to the third
    backbone feature map provides an auxiliary prediction.

    Parameters
    ----------
    nclass : int
        Number of output classes.
    backbone : str
        Backbone name, forwarded to ``BaseNet``.
    aux : bool, default True
        Build the auxiliary FCN head and append its output in ``forward``.
    se_loss : bool, default True
        Forwarded to ``BaseNet`` and ``EncHead``; when set, the head also
        returns SE-loss class logits.
    lateral : bool, default False
        Let ``EncHead`` fuse lateral features from earlier backbone stages.
    norm_layer : callable, default nn.BatchNorm2d
        2-D normalization layer constructor used by the heads.
    **kwargs
        Accepted but not forwarded to ``BaseNet``.
        NOTE(review): confirm that dropping these is intended.
    """

    def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(EncNet, self).__init__(nclass, backbone, aux, se_loss,
                                     norm_layer=norm_layer)
        # Channel counts of the backbone feature maps consumed by the heads
        # (hard-coded; presumably the ResNet stage widths -- confirm).
        head_in_channels = 2048
        aux_in_channels = 1024

        self.head = EncHead(self.nclass,
                            in_channels=head_in_channels,
                            se_loss=se_loss,
                            lateral=lateral,
                            norm_layer=norm_layer,
                            up_kwargs=self._up_kwargs)
        if aux:
            self.auxlayer = FCNHead(aux_in_channels, nclass,
                                    norm_layer=norm_layer)

    def forward(self, x):
        """Run the network and return a tuple of predictions.

        The tuple holds, in order:

        1. the head's segmentation map, upsampled to ``x``'s trailing
           (spatial) size;
        2. any further head outputs (the SE-loss logits when enabled);
        3. the upsampled auxiliary map, when ``self.aux`` is set.
        """
        out_size = x.size()[2:]
        feats = self.base_forward(x)

        head_outs = self.head(*feats)
        results = [F.upsample(head_outs[0], out_size, **self._up_kwargs)]
        results.extend(head_outs[1:])

        if self.aux:
            aux_map = self.auxlayer(feats[2])
            results.append(F.upsample(aux_map, out_size, **self._up_kwargs))
        return tuple(results)


class EncModule(nn.Module):
    """Context Encoding module: reweights feature channels from a global code.

    The input map passes through:

    * a 1x1 conv, BatchNorm2d and ReLU;
    * an ``encoding.nn.Encoding`` layer with ``ncodes`` codewords;
    * a 1-D norm and ReLU;
    * a mean over dim 1.

    The result feeds a ``Linear + Sigmoid`` gate that scales each channel of
    the input. The output is ``relu(x + x * gate)``.

    Parameters
    ----------
    in_channels : int
        Channel count of the input map (also the Encoding feature size ``D``).
    nclass : int
        Number of classes for the optional SE-loss logits.
    ncodes : int, default 32
        Number of codewords ``K`` in the Encoding layer.
    se_loss : bool, default True
        Also return class logits computed from the encoded vector.
    norm_layer : type or None
        The 2-D norm layer class used by the caller. A subclass of
        ``nn.BatchNorm2d`` selects plain ``nn.BatchNorm1d`` here. Anything
        else (including ``None`` or a factory/partial) selects
        ``encoding.nn.BatchNorm1d``.
    """

    def __init__(self, in_channels, nclass, ncodes=32, se_loss=True, norm_layer=None):
        super(EncModule, self).__init__()
        # `norm_layer` is a layer *class* (e.g. nn.BatchNorm2d), not an
        # instance, so it must be tested with issubclass; an isinstance test
        # is always False and would silently force the sync BatchNorm1d.
        if isinstance(norm_layer, type) and issubclass(norm_layer, nn.BatchNorm2d):
            norm_layer = nn.BatchNorm1d
        else:
            norm_layer = encoding.nn.BatchNorm1d
        self.se_loss = se_loss
        self.encoding = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, bias=False),
            # NOTE(review): hard-coded nn.BatchNorm2d ignores the caller's
            # norm_layer (e.g. SyncBN); kept to preserve pretrained behavior.
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            encoding.nn.Encoding(D=in_channels, K=ncodes),
            norm_layer(ncodes),
            nn.ReLU(inplace=True),
            # Average over dim 1 (presumably the K codewords -- confirm),
            # leaving an in_channels-wide vector per sample for the Linear below.
            encoding.nn.Mean(dim=1))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Sigmoid())
        if self.se_loss:
            self.selayer = nn.Linear(in_channels, nclass)

    def forward(self, x):
        """Return ``(relu(x + x * gate),)``, plus SE logits when ``se_loss``."""
        en = self.encoding(x)
        b, c, _, _ = x.size()
        gamma = self.fc(en)
        # Broadcast the per-channel gate over the spatial dims.
        y = gamma.view(b, c, 1, 1)
        outputs = [F.relu_(x + x * y)]
        if self.se_loss:
            outputs.append(self.selayer(en))
        return tuple(outputs)


class EncHead(nn.Module):
    """Segmentation head of EncNet.

    The head is built from these stages:

    * A 3x3 conv reduces the last input map (``inputs[-1]``) from
      ``in_channels`` to 512 channels.
    * When ``lateral`` is set, ``inputs[1]`` (512 channels) and ``inputs[2]``
      (1024 channels) are projected to 512 channels by 1x1 convs. They are
      concatenated with the reduced map and fused back to 512 channels by a
      3x3 conv.
    * The result goes through an ``EncModule``.
    * Its first output is classified by ``Dropout2d(0.1)`` followed by a 1x1
      conv to ``out_channels``.

    Parameters
    ----------
    out_channels : int
        Number of output classes.
    in_channels : int
        Channel count of the last input feature map.
    se_loss : bool, default True
        Forwarded to ``EncModule``.
    lateral : bool, default True
        Enable the lateral fusion branch.
    norm_layer : callable
        2-D normalization layer constructor (required in practice).
    up_kwargs : dict or None
        Stored as ``self.up_kwargs``; not used inside this class.
    """

    def __init__(self, out_channels, in_channels, se_loss=True, lateral=True,
                 norm_layer=None, up_kwargs=None):
        super(EncHead, self).__init__()
        self.se_loss = se_loss
        self.lateral = lateral
        self.up_kwargs = up_kwargs

        def conv_norm_relu(cin, cout, kernel_size, padding=0):
            # Bias-free conv followed by the configured norm and an in-place ReLU.
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size, padding=padding, bias=False),
                norm_layer(cout),
                nn.ReLU(inplace=True))

        # Submodules are created in the same order as before, so parameter
        # initialization consumes the RNG identically.
        self.conv5 = conv_norm_relu(in_channels, 512, 3, padding=1)
        if lateral:
            self.connect = nn.ModuleList([
                conv_norm_relu(512, 512, 1),
                conv_norm_relu(1024, 512, 1),
            ])
            self.fusion = conv_norm_relu(3 * 512, 512, 3, padding=1)
        self.encmodule = EncModule(512, out_channels, ncodes=32,
                                   se_loss=se_loss, norm_layer=norm_layer)
        self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False),
                                   nn.Conv2d(512, out_channels, 1))

    def forward(self, *inputs):
        """Return the classified map followed by any extra EncModule outputs."""
        feat = self.conv5(inputs[-1])
        if self.lateral:
            side_feats = [proj(inputs[idx])
                          for proj, idx in zip(self.connect, (1, 2))]
            feat = self.fusion(torch.cat([feat] + side_feats, 1))
        enc_outs = self.encmodule(feat)
        return (self.conv6(enc_outs[0]),) + tuple(enc_outs[1:])


def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
               root='~/.encoding/models', **kwargs):
    r"""EncNet model from the paper `"Context Encoding for Semantic Segmentation"
    <https://arxiv.org/pdf/1803.08904.pdf>`_

    Parameters
    ----------
    dataset : str, default pascal_voc
        The dataset that the model is pretrained on (pascal_voc, ade20k,
        pcontext). Matched case-insensitively.
    backbone : str, default resnet50
        The backbone network. (resnet50, 101, 152)
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.
    **kwargs
        Forwarded to ``EncNet``. Any ``lateral`` value is overridden: lateral
        fusion is enabled only for ``pcontext``.

    Raises
    ------
    KeyError
        If ``dataset`` is not a known dataset, or has no pretrained-weight
        acronym when ``pretrained`` is set.

    Examples
    --------
    >>> model = get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    >>> print(model)
    """
    acronyms = {
        'pascal_voc': 'voc',
        'ade20k': 'ade',
        'pcontext': 'pcontext',
    }
    # Normalize once so the class-count lookup and the pretrained-weight
    # lookup accept the same spellings.
    dataset = dataset.lower()
    kwargs['lateral'] = dataset == 'pcontext'
    # infer number of classes
    from ..datasets import datasets
    model = EncNet(datasets[dataset].NUM_CLASS, backbone=backbone, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_state_dict(torch.load(
            get_model_file('encnet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
    return model

def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""EncNet (ResNet-50 backbone, PASCAL-Context) from the paper
    `"Context Encoding for Semantic Segmentation"
    <https://arxiv.org/pdf/1803.08904.pdf>`_

    Builds the model with ``aux=False``. Lateral fusion is enabled by
    ``get_encnet`` for ``pcontext``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.
    **kwargs
        Forwarded to ``get_encnet`` (must not contain ``aux``).


    Examples
    --------
    >>> model = get_encnet_resnet50_pcontext(pretrained=True)
    >>> print(model)
    """
    return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)

def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""EncNet (ResNet-101 backbone, PASCAL-Context) from the paper
    `"Context Encoding for Semantic Segmentation"
    <https://arxiv.org/pdf/1803.08904.pdf>`_

    Builds the model with ``aux=False``. Lateral fusion is enabled by
    ``get_encnet`` for ``pcontext``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.
    **kwargs
        Forwarded to ``get_encnet`` (must not contain ``aux``).


    Examples
    --------
    >>> model = get_encnet_resnet101_pcontext(pretrained=True)
    >>> print(model)
    """
    return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=False, **kwargs)

def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
    r"""EncNet (ResNet-50 backbone, ADE20K) from the paper
    `"Context Encoding for Semantic Segmentation"
    <https://arxiv.org/pdf/1803.08904.pdf>`_

    Builds the model with ``aux=True`` (auxiliary FCN head enabled).

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.
    **kwargs
        Forwarded to ``get_encnet`` (must not contain ``aux``).


    Examples
    --------
    >>> model = get_encnet_resnet50_ade(pretrained=True)
    >>> print(model)
    """
    return get_encnet('ade20k', 'resnet50', pretrained, root=root, aux=True, **kwargs)