".github/vscode:/vscode.git/clone" did not exist on "997253143f36f8988c04f2bb184b0e8bb6e154d4"
deeplabv3.py 3.1 KB
Newer Older
1
2
3
from typing import List, Sequence

import torch
from torch import nn
from torch.nn import functional as F

from ._utils import _SimpleSegmentationModel
7
8


9
10
11
# Public API: only DeepLabV3 is exported by a star-import of this module.
__all__ = ["DeepLabV3"]


12
class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    This subclass defines no attributes or methods of its own; everything
    is inherited from ``_SimpleSegmentationModel``.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """


class DeepLabHead(nn.Sequential):
    """
    Segmentation head: an ``ASPP`` module, then a 3x3 convolution with
    batch norm and ReLU, then a 1x1 convolution producing ``num_classes``
    channels.

    Args:
        in_channels (int): number of channels of the input passed to ``ASPP``.
        num_classes (int): number of output channels of the final 1x1 convolution.
        atrous_rates (Sequence[int], optional): dilation rates forwarded to
            ``ASPP``. Defaults to ``(12, 24, 36)``, the values that were
            previously hard-coded here.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        atrous_rates: Sequence[int] = (12, 24, 36),
    ) -> None:
        super().__init__(
            # ASPP is annotated as taking a list, so hand it one.
            ASPP(in_channels, list(atrous_rates)),
            # 256 matches the default ``out_channels`` of ASPP.
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        )


class ASPPConv(nn.Sequential):
    """
    One atrous branch: 3x3 dilated convolution (no bias), batch norm, ReLU.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        dilation (int): dilation rate of the 3x3 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        super().__init__(
            # With a 3x3 kernel and stride 1, padding == dilation keeps the
            # spatial size of the input unchanged.
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """
    Image-level branch: global average pool to 1x1, 1x1 convolution
    (no bias), batch norm, ReLU, then bilinear upsampling back to the
    spatial size of the input.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` globally, transform it, and resize to ``x``'s last two dims."""
        # Remember the input's trailing two dimensions before pooling
        # collapses them to 1x1.
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    """
    Runs several branches over the same input, concatenates their outputs
    along the channel dimension, and projects the result back to
    ``out_channels``.

    The branches, in order, are: one 1x1 convolution block, one
    ``ASPPConv`` per entry of ``atrous_rates``, and one ``ASPPPooling``.

    Args:
        in_channels (int): number of channels of the input.
        atrous_rates (List[int]): dilation rate for each ``ASPPConv`` branch.
        out_channels (int, optional): number of channels produced by every
            branch and by the final projection. Default: 256.
    """

    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
        super().__init__()

        # Branch 1: plain 1x1 convolution.
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        # One dilated 3x3 branch per requested rate.
        for rate in atrous_rates:
            branches.append(ASPPConv(in_channels, out_channels, rate))
        # Last branch: image-level pooling.
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # Each branch emits ``out_channels`` channels, so the concatenation
        # has ``len(self.convs) * out_channels`` channels.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every branch to ``x``, concatenate on dim 1, and project."""
        branch_outputs = []
        for conv in self.convs:
            branch_outputs.append(conv(x))
        merged = torch.cat(branch_outputs, dim=1)
        return self.project(merged)