Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2cd25c1a
Unverified
Commit
2cd25c1a
authored
Feb 06, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 06, 2023
Browse files
Fix resnet_fpn_backbone(pretrained=True) (#7172)
parent
135a0f9e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
7 deletions
+12
-7
test/test_extended_models.py
test/test_extended_models.py
+6
-1
torchvision/models/_api.py
torchvision/models/_api.py
+4
-4
torchvision/models/detection/backbone_utils.py
torchvision/models/detection/backbone_utils.py
+2
-2
No files found.
test/test_extended_models.py
View file @
2cd25c1a
...
...
@@ -9,6 +9,7 @@ from common_extended_utils import get_file_size_mb, get_ops
from
torchvision
import
models
from
torchvision.models
import
get_model_weights
,
Weights
,
WeightsEnum
from
torchvision.models._utils
import
handle_legacy_interface
from
torchvision.models.detection.backbone_utils
import
mobilenet_backbone
,
resnet_fpn_backbone
run_if_test_with_extended
=
pytest
.
mark
.
skipif
(
os
.
getenv
(
"PYTORCH_TEST_WITH_EXTENDED"
,
"0"
)
!=
"1"
,
...
...
@@ -425,7 +426,11 @@ class TestHandleLegacyInterface:
+
TM
.
list_model_fns
(
models
.
quantization
)
+
TM
.
list_model_fns
(
models
.
segmentation
)
+
TM
.
list_model_fns
(
models
.
video
)
+
TM
.
list_model_fns
(
models
.
optical_flow
),
+
TM
.
list_model_fns
(
models
.
optical_flow
)
+
[
lambda
pretrained
:
resnet_fpn_backbone
(
backbone_name
=
"resnet50"
,
pretrained
=
pretrained
),
lambda
pretrained
:
mobilenet_backbone
(
backbone_name
=
"mobilenet_v2"
,
fpn
=
False
,
pretrained
=
pretrained
),
],
)
@
run_if_test_with_extended
def
test_pretrained_deprecation
(
self
,
model_fn
):
...
...
torchvision/models/_api.py
View file @
2cd25c1a
...
...
@@ -6,7 +6,7 @@ from enum import Enum
from
functools
import
partial
from
inspect
import
signature
from
types
import
ModuleType
from
typing
import
Any
,
Callable
,
cast
,
Dict
,
List
,
Mapping
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Type
,
TypeVar
,
Union
from
torch
import
nn
...
...
@@ -138,7 +138,7 @@ def get_weight(name: str) -> WeightsEnum:
return
weights_enum
[
value_name
]
def
get_model_weights
(
name
:
Union
[
Callable
,
str
])
->
WeightsEnum
:
def
get_model_weights
(
name
:
Union
[
Callable
,
str
])
->
Type
[
WeightsEnum
]
:
"""
Returns the weights enum class associated to the given model.
...
...
@@ -152,7 +152,7 @@ def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
return
_get_enum_from_fn
(
model
)
def
_get_enum_from_fn
(
fn
:
Callable
)
->
WeightsEnum
:
def
_get_enum_from_fn
(
fn
:
Callable
)
->
Type
[
WeightsEnum
]
:
"""
Internal method that gets the weight enum of a specific model builder method.
...
...
@@ -182,7 +182,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
"The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
return
cast
(
WeightsEnum
,
weights_enum
)
return
weights_enum
M
=
TypeVar
(
"M"
,
bound
=
nn
.
Module
)
...
...
torchvision/models/detection/backbone_utils.py
View file @
2cd25c1a
...
...
@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
lambda
kwargs
:
_get_enum_from_fn
(
resnet
.
__dict__
[
kwargs
[
"backbone_name"
]])
.
from_str
(
"IMAGENET1K_V1"
)
,
lambda
kwargs
:
_get_enum_from_fn
(
resnet
.
__dict__
[
kwargs
[
"backbone_name"
]])
[
"IMAGENET1K_V1"
]
,
),
)
def
resnet_fpn_backbone
(
...
...
@@ -177,7 +177,7 @@ def _validate_trainable_layers(
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
lambda
kwargs
:
_get_enum_from_fn
(
mobilenet
.
__dict__
[
kwargs
[
"backbone_name"
]])
.
from_str
(
"IMAGENET1K_V1"
)
,
lambda
kwargs
:
_get_enum_from_fn
(
mobilenet
.
__dict__
[
kwargs
[
"backbone_name"
]])
[
"IMAGENET1K_V1"
]
,
),
)
def
mobilenet_backbone
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment