from .._utils import IntermediateLayerGetter
from ..utils import load_state_dict_from_url
from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead


# Public API: the model factory functions defined in this module.
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']


# Checkpoint URLs keyed by '<arch_type>_<backbone>_coco' (the key format built
# in _load_model). A value of None means no checkpoint is registered for that
# combination; _load_model raises NotImplementedError for it when
# pretrained weights are requested.
model_urls = {
    'fcn_resnet50_coco': None,
    'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
    'deeplabv3_resnet50_coco': None,
    'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
}


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Assemble a segmentation model on top of a dilated ResNet backbone.

    Args:
        name (str): Segmentation architecture; one of ``'deeplabv3'`` or ``'fcn'``.
        backbone_name (str): Name of a backbone constructor, looked up in the
            ``resnet`` module's namespace (e.g. ``'resnet50'``).
        num_classes (int): Number of output classes passed to the classifier
            head(s).
        aux (bool): If truthy, the backbone's ``layer3`` output is also returned
            (under the key ``'aux'``) and an ``FCNHead`` auxiliary classifier is
            attached to it.
        pretrained_backbone (bool): Forwarded as ``pretrained`` to the backbone
            constructor.

    Returns:
        The ``FCN`` or ``DeepLabV3`` model wrapping the backbone, the main
        classifier and the (possibly ``None``) auxiliary classifier.

    Raises:
        KeyError: If ``name`` or ``backbone_name`` is not a known key.
    """
    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    # Resolve the architecture first so an unknown name fails before the
    # backbone is constructed (which may involve loading pretrained weights).
    head_cls, base_model = model_map[name]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    # 'layer4' feeds the main head; 'layer3' is only exposed for the aux head.
    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        # NOTE(review): 1024 is presumably the channel width of layer3 for
        # bottleneck ResNets (resnet50/101) -- confirm before adding backbones.
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)

    # NOTE(review): 2048 is presumably the channel width of layer4 for
    # bottleneck ResNets -- confirm before adding other backbones.
    inplanes = 2048
    classifier = head_cls(inplanes, num_classes)

    model = base_model(backbone, classifier, aux_classifier)
    return model


def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its pretrained checkpoint.

    Args:
        arch_type (str): Segmentation architecture name passed to
            ``_segm_resnet`` (``'fcn'`` or ``'deeplabv3'``).
        backbone (str): Backbone name passed to ``_segm_resnet``.
        pretrained (bool): If True, load the checkpoint registered in
            ``model_urls`` under ``'<arch_type>_<backbone>_coco'``.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        num_classes (int): Number of output classes.
        aux_loss (bool or None): Whether to attach the auxiliary classifier.
            Forced to True when ``pretrained`` is True.
        **kwargs: Forwarded to ``_segm_resnet`` (e.g. ``pretrained_backbone``).

    Returns:
        The constructed model, with checkpoint weights loaded when
        ``pretrained`` is True.

    Raises:
        NotImplementedError: If ``pretrained`` is True but no checkpoint URL is
            registered for the requested architecture/backbone pair.
        KeyError: If ``pretrained`` is True and the pair has no entry in
            ``model_urls`` at all.
    """
    model_url = None
    if pretrained:
        # Check checkpoint availability up front so an unsupported combination
        # fails before any model (or backbone download) work is done.
        arch = arch_type + '_' + backbone + '_coco'
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # Pretrained loading always builds the auxiliary classifier.
        # NOTE(review): presumably because the checkpoints contain aux-head
        # weights -- confirm against the published state dicts.
        aux_loss = True
        # The full checkpoint is loaded over the whole model below, so separately
        # fetching backbone weights is redundant unless the caller asks for it.
        kwargs.setdefault('pretrained_backbone', False)

    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model


def fcn_resnet50(pretrained=False, progress=True,
                 num_classes=21, aux_loss=None, **kwargs):
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): Number of output classes passed to the classifier heads
        aux_loss (bool, optional): If truthy, an auxiliary classifier is attached;
            forced to True when ``pretrained`` is True
        **kwargs: Forwarded to the model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: If ``pretrained`` is True but no checkpoint URL is
            registered for ``fcn_resnet50_coco``
    """
    return _load_model(
        arch_type='fcn',
        backbone='resnet50',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )


def fcn_resnet101(pretrained=False, progress=True,
                  num_classes=21, aux_loss=None, **kwargs):
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): Number of output classes passed to the classifier heads
        aux_loss (bool, optional): If truthy, an auxiliary classifier is attached;
            forced to True when ``pretrained`` is True
        **kwargs: Forwarded to the model builder (e.g. ``pretrained_backbone``)
    """
    return _load_model(
        arch_type='fcn',
        backbone='resnet101',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )


def deeplabv3_resnet50(pretrained=False, progress=True,
                       num_classes=21, aux_loss=None, **kwargs):
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): Number of output classes passed to the classifier heads
        aux_loss (bool, optional): If truthy, an auxiliary classifier is attached;
            forced to True when ``pretrained`` is True
        **kwargs: Forwarded to the model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: If ``pretrained`` is True but no checkpoint URL is
            registered for ``deeplabv3_resnet50_coco``
    """
    return _load_model(
        arch_type='deeplabv3',
        backbone='resnet50',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )


def deeplabv3_resnet101(pretrained=False, progress=True,
                        num_classes=21, aux_loss=None, **kwargs):
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): Number of output classes passed to the classifier heads
        aux_loss (bool, optional): If truthy, an auxiliary classifier is attached;
            forced to True when ``pretrained`` is True
        **kwargs: Forwarded to the model builder (e.g. ``pretrained_backbone``)
    """
    return _load_model(
        arch_type='deeplabv3',
        backbone='resnet101',
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs
    )