Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
aecbb150
"tests/python/common/test_batch-heterograph.py" did not exist on "5747542f35e549d71cf6852678349373981291e6"
Unverified
Commit
aecbb150
authored
Jan 27, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Jan 27, 2022
Browse files
Add IntermediateLayerGetter on segmentation. (#5298)
parent
b94004a6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
7 deletions
+7
-7
torchvision/models/segmentation/deeplabv3.py
torchvision/models/segmentation/deeplabv3.py
+3
-3
torchvision/models/segmentation/fcn.py
torchvision/models/segmentation/fcn.py
+2
-2
torchvision/models/segmentation/lraspp.py
torchvision/models/segmentation/lraspp.py
+2
-2
No files found.
torchvision/models/segmentation/deeplabv3.py
View file @
aecbb150
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
from
..
import
mobilenetv3
from
..
import
mobilenetv3
from
..
import
resnet
from
..
import
resnet
from
..
feature_extraction
import
create_feature_extracto
r
from
..
_utils
import
IntermediateLayerGette
r
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
from
.fcn
import
FCNHead
from
.fcn
import
FCNHead
...
@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
...
@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
return_layers
=
{
"layer4"
:
"out"
}
return_layers
=
{
"layer4"
:
"out"
}
if
aux
:
if
aux
:
return_layers
[
"layer3"
]
=
"aux"
return_layers
[
"layer3"
]
=
"aux"
backbone
=
create_feature_extractor
(
backbone
,
return_layers
)
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
return_layers
)
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
classifier
=
DeepLabHead
(
2048
,
num_classes
)
classifier
=
DeepLabHead
(
2048
,
num_classes
)
...
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
...
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
return_layers
=
{
str
(
out_pos
):
"out"
}
return_layers
=
{
str
(
out_pos
):
"out"
}
if
aux
:
if
aux
:
return_layers
[
str
(
aux_pos
)]
=
"aux"
return_layers
[
str
(
aux_pos
)]
=
"aux"
backbone
=
create_feature_extractor
(
backbone
,
return_layers
)
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
return_layers
)
aux_classifier
=
FCNHead
(
aux_inplanes
,
num_classes
)
if
aux
else
None
aux_classifier
=
FCNHead
(
aux_inplanes
,
num_classes
)
if
aux
else
None
classifier
=
DeepLabHead
(
out_inplanes
,
num_classes
)
classifier
=
DeepLabHead
(
out_inplanes
,
num_classes
)
...
...
torchvision/models/segmentation/fcn.py
View file @
aecbb150
...
@@ -3,7 +3,7 @@ from typing import Optional
...
@@ -3,7 +3,7 @@ from typing import Optional
from
torch
import
nn
from
torch
import
nn
from
..
import
resnet
from
..
import
resnet
from
..
feature_extraction
import
create_feature_extracto
r
from
..
_utils
import
IntermediateLayerGette
r
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
from
._utils
import
_SimpleSegmentationModel
,
_load_weights
...
@@ -57,7 +57,7 @@ def _fcn_resnet(
...
@@ -57,7 +57,7 @@ def _fcn_resnet(
return_layers
=
{
"layer4"
:
"out"
}
return_layers
=
{
"layer4"
:
"out"
}
if
aux
:
if
aux
:
return_layers
[
"layer3"
]
=
"aux"
return_layers
[
"layer3"
]
=
"aux"
backbone
=
create_feature_extractor
(
backbone
,
return_layers
)
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
return_layers
)
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
aux_classifier
=
FCNHead
(
1024
,
num_classes
)
if
aux
else
None
classifier
=
FCNHead
(
2048
,
num_classes
)
classifier
=
FCNHead
(
2048
,
num_classes
)
...
...
torchvision/models/segmentation/lraspp.py
View file @
aecbb150
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
...
@@ -6,7 +6,7 @@ from torch.nn import functional as F
from
...utils
import
_log_api_usage_once
from
...utils
import
_log_api_usage_once
from
..
import
mobilenetv3
from
..
import
mobilenetv3
from
..
feature_extraction
import
create_feature_extracto
r
from
..
_utils
import
IntermediateLayerGette
r
from
._utils
import
_load_weights
from
._utils
import
_load_weights
...
@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
...
@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
high_pos
=
stage_indices
[
-
1
]
# use C5 which has output_stride = 16
high_pos
=
stage_indices
[
-
1
]
# use C5 which has output_stride = 16
low_channels
=
backbone
[
low_pos
].
out_channels
low_channels
=
backbone
[
low_pos
].
out_channels
high_channels
=
backbone
[
high_pos
].
out_channels
high_channels
=
backbone
[
high_pos
].
out_channels
backbone
=
create_feature_extractor
(
backbone
,
{
str
(
low_pos
):
"low"
,
str
(
high_pos
):
"high"
})
backbone
=
IntermediateLayerGetter
(
backbone
,
return_layers
=
{
str
(
low_pos
):
"low"
,
str
(
high_pos
):
"high"
})
return
LRASPP
(
backbone
,
low_channels
,
high_channels
,
num_classes
)
return
LRASPP
(
backbone
,
low_channels
,
high_channels
,
num_classes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment