segmentation.py 1.85 KB
Newer Older
1
2
3
4
5
6
7
import warnings

from .._utils import IntermediateLayerGetter
from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead


# Public model constructors exported by ``from <module> import *``.
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Assemble a segmentation model from a dilated ResNet backbone and a head.

    Args:
        name: Architecture key; ``'deeplab'`` selects ``DeepLabHead`` +
            ``DeepLabV3`` and ``'fcn'`` selects ``FCNHead`` + ``FCN``.
        backbone_name: Name of a constructor in the ``resnet`` module,
            e.g. ``'resnet50'``.
        num_classes: Number of output channels of the classifier head(s).
        aux: If truthy, ``layer3`` features are also returned by the backbone
            (under the key ``'aux'``) and an ``FCNHead`` auxiliary classifier
            is attached; otherwise the auxiliary classifier is ``None``.
        pretrained_backbone: Forwarded to the ResNet constructor as
            ``pretrained``.

    Returns:
        An instance of the selected model class built as
        ``Model(backbone, classifier, aux_classifier)``.

    Raises:
        KeyError: if ``backbone_name`` is not an attribute of ``resnet`` or
            ``name`` is not one of the supported architecture keys.
    """
    heads_by_arch = {
        'deeplab': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }

    # The later ResNet stages use dilation instead of striding
    # (per the constructor's replace_stride_with_dilation flag).
    build_backbone = resnet.__dict__[backbone_name]
    trunk = build_backbone(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
    )

    if aux:
        wanted = {'layer4': 'out', 'layer3': 'aux'}
    else:
        wanted = {'layer4': 'out'}
    features = IntermediateLayerGetter(trunk, return_layers=wanted)

    # 1024 / 2048 are presumably the layer3 / layer4 widths of Bottleneck
    # ResNets (resnet50/101) -- other backbones would need different values.
    aux_head = FCNHead(1024, num_classes) if aux else None

    head_cls, model_cls = heads_by_arch[name]
    main_head = head_cls(2048, num_classes)
    return model_cls(features, main_head, aux_head)


def fcn_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
    """Construct a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained: Request pretrained segmentation weights. None are wired
            up for this model yet, so a ``UserWarning`` is emitted and the
            model is returned without them.
        num_classes: Number of output classes.
        aux_loss: If truthy, also attach an auxiliary ``FCNHead`` classifier
            fed from the backbone's ``layer3`` features.
        **kwargs: Forwarded to ``_segm_resnet`` (which accepts
            ``pretrained_backbone``).

    Returns:
        The ``FCN`` model built by ``_segm_resnet``.
    """
    model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs)
    if pretrained:
        # No segmentation checkpoint is loaded here; tell the caller rather
        # than silently returning a model they believe is pretrained.
        warnings.warn(
            "pretrained=True is not supported yet for fcn_resnet50: no "
            "pretrained segmentation weights are available, returning a "
            "model without them.",
            stacklevel=2,
        )
    return model


def fcn_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
    """Construct a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained: Request pretrained segmentation weights. None are wired
            up for this model yet, so a ``UserWarning`` is emitted and the
            model is returned without them.
        num_classes: Number of output classes.
        aux_loss: If truthy, also attach an auxiliary ``FCNHead`` classifier
            fed from the backbone's ``layer3`` features.
        **kwargs: Forwarded to ``_segm_resnet`` (which accepts
            ``pretrained_backbone``).

    Returns:
        The ``FCN`` model built by ``_segm_resnet``.
    """
    model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs)
    if pretrained:
        # No segmentation checkpoint is loaded here; tell the caller rather
        # than silently returning a model they believe is pretrained.
        warnings.warn(
            "pretrained=True is not supported yet for fcn_resnet101: no "
            "pretrained segmentation weights are available, returning a "
            "model without them.",
            stacklevel=2,
        )
    return model


def deeplabv3_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
    """Construct a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained: Request pretrained segmentation weights. None are wired
            up for this model yet, so a ``UserWarning`` is emitted and the
            model is returned without them.
        num_classes: Number of output classes.
        aux_loss: If truthy, also attach an auxiliary ``FCNHead`` classifier
            fed from the backbone's ``layer3`` features.
        **kwargs: Forwarded to ``_segm_resnet`` (which accepts
            ``pretrained_backbone``).

    Returns:
        The ``DeepLabV3`` model built by ``_segm_resnet``.
    """
    model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs)
    if pretrained:
        # No segmentation checkpoint is loaded here; tell the caller rather
        # than silently returning a model they believe is pretrained.
        warnings.warn(
            "pretrained=True is not supported yet for deeplabv3_resnet50: no "
            "pretrained segmentation weights are available, returning a "
            "model without them.",
            stacklevel=2,
        )
    return model


def deeplabv3_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs):
    """Construct a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained: Request pretrained segmentation weights. None are wired
            up for this model yet, so a ``UserWarning`` is emitted and the
            model is returned without them.
        num_classes: Number of output classes.
        aux_loss: If truthy, also attach an auxiliary ``FCNHead`` classifier
            fed from the backbone's ``layer3`` features.
        **kwargs: Forwarded to ``_segm_resnet`` (which accepts
            ``pretrained_backbone``).

    Returns:
        The ``DeepLabV3`` model built by ``_segm_resnet``.
    """
    model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs)
    if pretrained:
        # No segmentation checkpoint is loaded here; tell the caller rather
        # than silently returning a model they believe is pretrained.
        warnings.warn(
            "pretrained=True is not supported yet for deeplabv3_resnet101: no "
            "pretrained segmentation weights are available, returning a "
            "model without them.",
            stacklevel=2,
        )
    return model