"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "2c11f9c2eba40e922a0e5a6a4205cc37b17d9165"
deeplabv3.py 3.1 KB
Newer Older
1
2
from typing import List, Sequence

import torch
from torch import nn
from torch.nn import functional as F

from ._utils import _SimpleSegmentationModel
8
9


10
11
12
# Names exported by ``from <module> import *``; only the model class is listed.
__all__ = ["DeepLabV3"]


13
class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    # No members are added or overridden here: construction and the forward
    # pass come entirely from _SimpleSegmentationModel.
    pass


class DeepLabHead(nn.Sequential):
    """
    DeepLabV3 classification head: an ASPP module followed by a 3x3
    convolution block and a final 1x1 convolution producing ``num_classes``
    channels.

    Args:
        in_channels (int): number of channels of the tensor fed to the head.
        num_classes (int): number of output channels of the last 1x1 convolution.
        atrous_rates (Sequence[int], optional): dilation rates forwarded to ASPP.
            Defaults to ``(12, 24, 36)``, the values that used to be hard-coded,
            so existing two-argument callers are unaffected.
    """

    def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
        super().__init__(
            # ASPP is built with its default out_channels (256), which is why
            # the layers below are fixed at 256 channels.
            ASPP(in_channels, atrous_rates),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        )


class ASPPConv(nn.Sequential):
    """
    One dilated branch of ASPP: 3x3 atrous convolution -> BatchNorm -> ReLU.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        dilation (int): dilation rate of the 3x3 convolution. It is also used
            as the padding, so the spatial size of the input is preserved.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        super().__init__(
            # bias is disabled because the following BatchNorm has its own shift.
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """
    Image-level pooling branch of ASPP: global average pool -> 1x1 convolution
    -> BatchNorm -> ReLU, then bilinear upsampling back to the input's
    spatial size.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the pooled branch and resize the result to ``x``'s last two dimensions.

        Args:
            x (Tensor): input whose last two dimensions are the spatial size.

        Returns:
            Tensor: branch output with the same last two dimensions as ``x``.
        """
        # Remember the spatial size before pooling collapses it to 1x1.
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
67
68
69


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling.

    Runs the input through parallel branches (one 1x1 convolution, one
    ``ASPPConv`` per atrous rate, and one ``ASPPPooling``), concatenates the
    branch outputs along the channel dimension and projects them back to
    ``out_channels`` channels.

    Args:
        in_channels (int): number of input channels.
        atrous_rates (Sequence[int]): dilation rate of each ``ASPPConv`` branch.
        out_channels (int, optional): number of channels produced by every
            branch and by the final projection. Default: 256.
    """

    def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
        super().__init__()

        # Branch 1: plain 1x1 convolution.
        modules = [
            nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
        ]

        # One dilated 3x3 branch per requested rate.
        for rate in tuple(atrous_rates):
            modules.append(ASPPConv(in_channels, out_channels, rate))

        # Final branch: image-level pooling.
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # The projection input width is the channel-wise concatenation of all
        # branches, hence len(self.convs) * out_channels.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply every branch to ``x``, concatenate on dim 1 and project.

        Args:
            x (Tensor): input feature map.

        Returns:
            Tensor: output of the projection block.
        """
        branch_outputs = []
        for conv in self.convs:
            branch_outputs.append(conv(x))
        merged = torch.cat(branch_outputs, dim=1)
        return self.project(merged)