Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
aecbb150
Unverified
Commit
aecbb150
authored
Jan 27, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Jan 27, 2022
Browse files
Add IntermediateLayerGetter on segmentation. (#5298)
parent
b94004a6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
7 deletions
+7
-7
torchvision/models/segmentation/deeplabv3.py
torchvision/models/segmentation/deeplabv3.py
+3
-3
torchvision/models/segmentation/fcn.py
torchvision/models/segmentation/fcn.py
+2
-2
torchvision/models/segmentation/lraspp.py
torchvision/models/segmentation/lraspp.py
+2
-2
No files found.
torchvision/models/segmentation/deeplabv3.py
View file @
aecbb150
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
from
..
import
mobilenetv3
from
..
import
mobilenetv3
from
..
import
resnet
from
..
import
resnet
from
..
feature_extraction
import
create_feature_extracto
r
from
..
_utils
import
IntermediateLayerGette
r
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
from
.fcn
import
FCNHead
from
.fcn
import
FCNHead
...
@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
...
@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
return_layers
=
{
"layer4"
:
"out"
}
return_layers
=
{
"layer4"
:
"out"
}
if
aux
:
if
aux
:
return_layers
[
"layer3"
]
=
"aux"
return_layers
[
"layer3"
]
=
"aux"
backbone
=
create_feature_extractor
(
backbone
,
return_layers
)
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
return_layers
)
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
classifier
=
DeepLabHead
(
2048
,
num_classes
)
classifier
=
DeepLabHead
(
2048
,
num_classes
)
...
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
...
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
return_layers
=
{
str
(
out_pos
):
"out"
}
return_layers
=
{
str
(
out_pos
):
"out"
}
if
aux
:
if
aux
:
return_layers
[
str
(
aux_pos
)]
=
"aux"
return_layers
[
str
(
aux_pos
)]
=
"aux"
backbone
=
create_feature_extractor
(
backbone
,
return_layers
)
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
return_layers
)
aux_classifier
=
FCNHead
(
aux_inplanes
,
num_classes
)
if
aux
else
None
aux_classifier
=
FCNHead
(
aux_inplanes
,
num_classes
)
if
aux
else
None
classifier
=
DeepLabHead
(
out_inplanes
,
num_classes
)
classifier
=
DeepLabHead
(
out_inplanes
,
num_classes
)
...
...
torchvision/models/segmentation/fcn.py
View file @
aecbb150
...
@@ -3,7 +3,7 @@ from typing import Optional
...
@@ -3,7 +3,7 @@ from typing import Optional
from
torch
import
nn
from
torch
import
nn
from
..
import
resnet
from
..
import
resnet
from
..
feature_extraction
import
create_feature_extracto
r
from
..
_utils
import
IntermediateLayerGette
r
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
...
@@ -57,7 +57,7 @@ def _fcn_resnet(
...
@@ -57,7 +57,7 @@ def _fcn_resnet(
return_layers
=
{
"layer4"
:
"out"
}
return_layers
=
{
"layer4"
:
"out"
}
if
aux
:
if
aux
:
return_layers
[
"layer3"
]
=
"aux"
return_layers
[
"layer3"
]
=
"aux"
backbone
=
create_feature_extractor
(
backbone
,
return_layers
)
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
return_layers
)
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
classifier
=
FCNHead
(
2048
,
num_classes
)
classifier
=
FCNHead
(
2048
,
num_classes
)
...
...
torchvision/models/segmentation/lraspp.py
View file @
aecbb150
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
from
...utils
import
_log_api_usage_once
from
...utils
import
_log_api_usage_once
from
..
import
mobilenetv3
from
..
import
mobilenetv3
from
..
feature_extraction
import
create_feature_extracto
r
from
..
_utils
import
IntermediateLayerGette
r
from
._utils
import
_load_weights
from
._utils
import
_load_weights
...
@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
...
@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
high_pos
=
stage_indices
[
-
1
]
# use C5 which has output_stride = 16
high_pos
=
stage_indices
[
-
1
]
# use C5 which has output_stride = 16
low_channels
=
backbone
[
low_pos
].
out_channels
low_channels
=
backbone
[
low_pos
].
out_channels
high_channels
=
backbone
[
high_pos
].
out_channels
high_channels
=
backbone
[
high_pos
].
out_channels
backbone
=
create_feature_extractor
(
backbone
,
{
str
(
low_pos
):
"low"
,
str
(
high_pos
):
"high"
})
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
{
str
(
low_pos
):
"low"
,
str
(
high_pos
):
"high"
})
return
LRASPP
(
backbone
,
low_channels
,
high_channels
,
num_classes
)
return
LRASPP
(
backbone
,
low_channels
,
high_channels
,
num_classes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment