Unverified Commit 86db394e authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Move segmentation models to its own folder (#918)

* Move segmentation models to its own folder

* Add missing files
parent 967ef26c
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
class _SimpleSegmentationModel(nn.Module):
def __init__(self, backbone, classifier, aux_classifier=None):
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone
self.classifier = classifier
self.aux_classifier = aux_classifier
def forward(self, x):
input_shape = x.shape[-2:]
# contract: features is a dict of tensors
features = self.backbone(x)
result = OrderedDict()
x = features["out"]
x = self.classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["out"] = x
if self.aux_classifier is not None:
x = features["aux"]
x = self.aux_classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["aux"] = x
return result
from collections import OrderedDict
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from ._utils import _SimpleSegmentationModel
class _SimpleSegmentationModel(nn.Module):
def __init__(self, backbone, classifier, aux_classifier=None):
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone
self.classifier = classifier
self.aux_classifier = aux_classifier
def forward(self, x):
input_shape = x.shape[-2:]
# contract: features is a dict of tensors
features = self.backbone(x)
result = OrderedDict()
x = features["out"]
x = self.classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["out"] = x
if self.aux_classifier is not None:
x = features["aux"]
x = self.aux_classifier(x)
x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result["aux"] = x
return result
class FCN(_SimpleSegmentationModel):
pass
class DeepLabV3(_SimpleSegmentationModel): class DeepLabV3(_SimpleSegmentationModel):
pass pass
class FCNHead(nn.Sequential):
def __init__(self, in_channels, channels):
inter_channels = in_channels // 4
layers = [
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(),
nn.Dropout(0.1),
nn.Conv2d(inter_channels, channels, 1)
]
super(FCNHead, self).__init__(*layers)
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
"""
class DeepLabHead(nn.Sequential): class DeepLabHead(nn.Sequential):
def __init__(self, in_channels, num_classes): def __init__(self, in_channels, num_classes):
super(DeepLabHead, self).__init__( super(DeepLabHead, self).__init__(
...@@ -71,14 +18,6 @@ class DeepLabHead(nn.Sequential): ...@@ -71,14 +18,6 @@ class DeepLabHead(nn.Sequential):
nn.ReLU(), nn.ReLU(),
nn.Conv2d(256, num_classes, 1) nn.Conv2d(256, num_classes, 1)
) )
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
"""
class ASPPConv(nn.Sequential): class ASPPConv(nn.Sequential):
......
from torch import nn
from ._utils import _SimpleSegmentationModel
class FCN(_SimpleSegmentationModel):
pass
class FCNHead(nn.Sequential):
def __init__(self, in_channels, channels):
inter_channels = in_channels // 4
layers = [
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(),
nn.Dropout(0.1),
nn.Conv2d(inter_channels, channels, 1)
]
super(FCNHead, self).__init__(*layers)
from ._utils import IntermediateLayerGetter from .._utils import IntermediateLayerGetter
from . import resnet from .. import resnet
from .deeplabv3 import FCN, FCNHead, DeepLabHead, DeepLabV3 from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True): def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment