from .._utils import IntermediateLayerGetter
from ..utils import load_state_dict_from_url
from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead


__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']


# Checkpoint URLs keyed by '<arch>_<backbone>_coco'; None means no published
# pretrained weights exist for that combination.
model_urls = {
    'fcn_resnet50_coco': None,
    'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
    'deeplabv3_resnet50_coco': None,
    'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
}


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    """Build a segmentation model on top of a dilated ResNet backbone.

    Args:
        name (str): head family to use, ``'fcn'`` or ``'deeplab'``.
        backbone_name (str): name of a constructor in the ``resnet`` module,
            e.g. ``'resnet50'``.
        num_classes (int): number of output channels of the classifier head(s).
        aux (bool): if truthy, also attach an ``FCNHead`` auxiliary classifier
            fed by the output of ``layer3``.
        pretrained_backbone (bool): forwarded as ``pretrained`` to the backbone
            constructor.

    Returns:
        The ``FCN`` or ``DeepLabV3`` model wrapping the backbone and heads.
    """
    # Dilate layer3 and layer4 instead of striding them, so the feature map
    # keeps a higher spatial resolution for dense prediction.
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        # NOTE(review): 1024 is presumably the layer3 output width of the
        # bottleneck ResNets (resnet50/resnet101) used by this module's
        # builders; other backbones would need a different value.
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)

    model_map = {
        'deeplab': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    head_cls, model_cls = model_map[name]
    # NOTE(review): 2048 is presumably the layer4 output width of
    # resnet50/resnet101, matching the 1024 assumption above.
    inplanes = 2048
    classifier = head_cls(inplanes, num_classes)

    return model_cls(backbone, classifier, aux_classifier)


def _load_model(name, backbone_name, arch, pretrained, progress,
                num_classes, aux_loss, **kwargs):
    """Shared implementation of the public builders below.

    Args:
        name (str): head family passed to ``_segm_resnet``.
        backbone_name (str): backbone constructor name passed to ``_segm_resnet``.
        arch (str): key into ``model_urls`` for the pretrained checkpoint.
        pretrained (bool): if True, load the checkpoint for ``arch``.
        progress (bool): show a download progress bar on stderr.
        num_classes (int): number of output channels of the classifier head(s).
        aux_loss (bool or None): whether to build the auxiliary classifier;
            forced to True when ``pretrained`` is True.
        **kwargs: extra keyword arguments forwarded to ``_segm_resnet``.

    Raises:
        NotImplementedError: if ``pretrained`` is True and ``model_urls`` has
            no checkpoint for ``arch``.
    """
    model_url = None
    if pretrained:
        model_url = model_urls[arch]
        # Fail before building the network, so an unsupported request does
        # not first pay for constructing it and downloading backbone weights.
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # The checkpoint is loaded with strict key matching and the builders
        # always built the aux head for it, so it must be present.
        aux_loss = True
        # Every backbone parameter is overwritten by the strict checkpoint
        # load below, so downloading ImageNet backbone weights is wasted work.
        kwargs['pretrained_backbone'] = False

    model = _segm_resnet(name, backbone_name, num_classes, aux_loss, **kwargs)
    if model_url is not None:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model


def fcn_resnet50(pretrained=False, progress=True,
                 num_classes=21, aux_loss=None, **kwargs):
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output channels of the classifier head(s)
        aux_loss (bool): If True, also builds an auxiliary classifier on ``layer3``;
            forced to True when ``pretrained`` is True
        **kwargs: forwarded to the internal model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: if ``pretrained`` is True and no checkpoint is published
            for this architecture
    """
    return _load_model('fcn', 'resnet50', 'fcn_resnet50_coco', pretrained, progress,
                       num_classes, aux_loss, **kwargs)


def fcn_resnet101(pretrained=False, progress=True,
                  num_classes=21, aux_loss=None, **kwargs):
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output channels of the classifier head(s)
        aux_loss (bool): If True, also builds an auxiliary classifier on ``layer3``;
            forced to True when ``pretrained`` is True
        **kwargs: forwarded to the internal model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: if ``pretrained`` is True and no checkpoint is published
            for this architecture
    """
    return _load_model('fcn', 'resnet101', 'fcn_resnet101_coco', pretrained, progress,
                       num_classes, aux_loss, **kwargs)


def deeplabv3_resnet50(pretrained=False, progress=True,
                       num_classes=21, aux_loss=None, **kwargs):
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output channels of the classifier head(s)
        aux_loss (bool): If True, also builds an auxiliary classifier on ``layer3``;
            forced to True when ``pretrained`` is True
        **kwargs: forwarded to the internal model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: if ``pretrained`` is True and no checkpoint is published
            for this architecture
    """
    return _load_model('deeplab', 'resnet50', 'deeplabv3_resnet50_coco', pretrained, progress,
                       num_classes, aux_loss, **kwargs)


def deeplabv3_resnet101(pretrained=False, progress=True,
                        num_classes=21, aux_loss=None, **kwargs):
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output channels of the classifier head(s)
        aux_loss (bool): If True, also builds an auxiliary classifier on ``layer3``;
            forced to True when ``pretrained`` is True
        **kwargs: forwarded to the internal model builder (e.g. ``pretrained_backbone``)

    Raises:
        NotImplementedError: if ``pretrained`` is True and no checkpoint is published
            for this architecture
    """
    return _load_model('deeplab', 'resnet101', 'deeplabv3_resnet101_coco', pretrained, progress,
                       num_classes, aux_loss, **kwargs)