Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
090d8237
Unverified
Commit
090d8237
authored
Mar 11, 2022
by
talregev
Committed by
GitHub
Mar 11, 2022
Browse files
Improve test of backbone utils (#5552)
parent
a8bde781
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
2 deletions
+8
-2
test/test_backbone_utils.py
test/test_backbone_utils.py
+8
-2
No files found.
test/test_backbone_utils.py
View file @
090d8237
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
common_utils
import
set_rng_seed
from
common_utils
import
set_rng_seed
from
torchvision
import
models
from
torchvision
import
models
from
torchvision.models._utils
import
IntermediateLayerGetter
from
torchvision.models._utils
import
IntermediateLayerGetter
from
torchvision.models.detection.backbone_utils
import
mobilenet_backbone
,
resnet_fpn_backbone
from
torchvision.models.detection.backbone_utils
import
BackboneWithFPN
,
mobilenet_backbone
,
resnet_fpn_backbone
from
torchvision.models.feature_extraction
import
create_feature_extractor
,
get_graph_node_names
from
torchvision.models.feature_extraction
import
create_feature_extractor
,
get_graph_node_names
...
@@ -19,7 +19,9 @@ def get_available_models():
...
@@ -19,7 +19,9 @@ def get_available_models():
@
pytest
.
mark
.
parametrize
(
"backbone_name"
,
(
"resnet18"
,
"resnet50"
))
@
pytest
.
mark
.
parametrize
(
"backbone_name"
,
(
"resnet18"
,
"resnet50"
))
def
test_resnet_fpn_backbone
(
backbone_name
):
def
test_resnet_fpn_backbone
(
backbone_name
):
x
=
torch
.
rand
(
1
,
3
,
300
,
300
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
x
=
torch
.
rand
(
1
,
3
,
300
,
300
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
y
=
resnet_fpn_backbone
(
backbone_name
=
backbone_name
,
pretrained
=
False
)(
x
)
model
=
resnet_fpn_backbone
(
backbone_name
=
backbone_name
,
pretrained
=
False
)
assert
isinstance
(
model
,
BackboneWithFPN
)
y
=
model
(
x
)
assert
list
(
y
.
keys
())
==
[
"0"
,
"1"
,
"2"
,
"3"
,
"pool"
]
assert
list
(
y
.
keys
())
==
[
"0"
,
"1"
,
"2"
,
"3"
,
"pool"
]
with
pytest
.
raises
(
ValueError
,
match
=
r
"Trainable layers should be in the range"
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"Trainable layers should be in the range"
):
...
@@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name):
...
@@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name):
mobilenet_backbone
(
backbone_name
,
False
,
fpn
=
True
,
returned_layers
=
[
-
1
,
0
,
1
,
2
])
mobilenet_backbone
(
backbone_name
,
False
,
fpn
=
True
,
returned_layers
=
[
-
1
,
0
,
1
,
2
])
with
pytest
.
raises
(
ValueError
,
match
=
r
"Each returned layer should be in the range"
):
with
pytest
.
raises
(
ValueError
,
match
=
r
"Each returned layer should be in the range"
):
mobilenet_backbone
(
backbone_name
,
False
,
fpn
=
True
,
returned_layers
=
[
3
,
4
,
5
,
6
])
mobilenet_backbone
(
backbone_name
,
False
,
fpn
=
True
,
returned_layers
=
[
3
,
4
,
5
,
6
])
model_fpn
=
mobilenet_backbone
(
backbone_name
,
False
,
fpn
=
True
)
assert
isinstance
(
model_fpn
,
BackboneWithFPN
)
model
=
mobilenet_backbone
(
backbone_name
,
False
,
fpn
=
False
)
assert
isinstance
(
model
,
torch
.
nn
.
Sequential
)
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment