Unverified Commit aecbb150 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add IntermediateLayerGetter on segmentation. (#5298)

parent b94004a6
......@@ -6,7 +6,7 @@ from torch.nn import functional as F
from .. import mobilenetv3
from .. import resnet
from ..feature_extraction import create_feature_extractor
from .._utils import IntermediateLayerGetter
from ._utils import _SimpleSegmentationModel, _load_weights
from .fcn import FCNHead
......@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
return_layers = {"layer4": "out"}
if aux:
return_layers["layer3"] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = DeepLabHead(2048, num_classes)
......@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
return_layers = {str(out_pos): "out"}
if aux:
return_layers[str(aux_pos)] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
classifier = DeepLabHead(out_inplanes, num_classes)
......
......@@ -3,7 +3,7 @@ from typing import Optional
from torch import nn
from .. import resnet
from ..feature_extraction import create_feature_extractor
from .._utils import IntermediateLayerGetter
from ._utils import _SimpleSegmentationModel, _load_weights
......@@ -57,7 +57,7 @@ def _fcn_resnet(
return_layers = {"layer4": "out"}
if aux:
return_layers["layer3"] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = FCNHead(2048, num_classes)
......
......@@ -6,7 +6,7 @@ from torch.nn import functional as F
from ...utils import _log_api_usage_once
from .. import mobilenetv3
from ..feature_extraction import create_feature_extractor
from .._utils import IntermediateLayerGetter
from ._utils import _load_weights
......@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})
return LRASPP(backbone, low_channels, high_channels, num_classes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment