"vscode:/vscode.git/clone" did not exist on "df73d3e110347d5ae7808f6492a6b115aba257b7"
Unverified Commit 4614cf93 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Use FX feature extractor for segm model (#4563)

* Use FX feature extractor for segm model

* Removed use_fe option
parent 2256b495
......@@ -5,7 +5,7 @@ from torch import nn
from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3
from .. import resnet
from .._utils import IntermediateLayerGetter
from ..feature_extraction import create_feature_extractor
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
from .lraspp import LRASPP
......@@ -60,7 +60,7 @@ def _segm_model(
return_layers = {out_layer: "out"}
if aux:
return_layers[aux_layer] = "aux"
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
backbone = create_feature_extractor(backbone, return_layers)
aux_classifier = None
if aux:
......@@ -116,7 +116,7 @@ def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_ba
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
model = LRASPP(backbone, low_channels, high_channels, num_classes)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment