Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
cac4e228
Unverified
Commit
cac4e228
authored
Sep 12, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Sep 12, 2022
Browse files
Make get_model_builder public (#6560)
parent
a67cc87a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
7 deletions
+33
-7
test/test_extended_models.py
test/test_extended_models.py
+15
-0
test/test_models.py
test/test_models.py
+2
-2
torchvision/models/__init__.py
torchvision/models/__init__.py
+1
-1
torchvision/models/_api.py
torchvision/models/_api.py
+15
-4
No files found.
test/test_extended_models.py
View file @
cac4e228
...
...
@@ -29,6 +29,21 @@ def test_get_model(name, model_class):
assert
isinstance
(
models
.
get_model
(
name
),
model_class
)
@
pytest
.
mark
.
parametrize
(
"name, model_fn"
,
[
(
"resnet50"
,
models
.
resnet50
),
(
"retinanet_resnet50_fpn_v2"
,
models
.
detection
.
retinanet_resnet50_fpn_v2
),
(
"raft_large"
,
models
.
optical_flow
.
raft_large
),
(
"quantized_resnet50"
,
models
.
quantization
.
resnet50
),
(
"lraspp_mobilenet_v3_large"
,
models
.
segmentation
.
lraspp_mobilenet_v3_large
),
(
"mvit_v1_b"
,
models
.
video
.
mvit_v1_b
),
],
)
def
test_get_model_builder
(
name
,
model_fn
):
assert
models
.
get_model_builder
(
name
)
==
model_fn
@
pytest
.
mark
.
parametrize
(
"name, weight"
,
[
...
...
test/test_models.py
View file @
cac4e228
...
...
@@ -17,7 +17,7 @@ import torch.nn as nn
from
_utils_internal
import
get_relative_path
from
common_utils
import
cpu_and_gpu
,
freeze_rng_state
,
map_nested_tensor_object
,
needs_cuda
,
set_rng_seed
from
torchvision
import
models
from
torchvision.models
._api
import
find
_model
,
list_models
from
torchvision.models
import
get
_model
_builder
,
list_models
ACCEPT
=
os
.
getenv
(
"EXPECTTEST_ACCEPT"
,
"0"
)
==
"1"
...
...
@@ -25,7 +25,7 @@ SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
def
list_model_fns
(
module
):
return
[
find
_model
(
name
)
for
name
in
list_models
(
module
)]
return
[
get
_model
_builder
(
name
)
for
name
in
list_models
(
module
)]
@
pytest
.
fixture
...
...
torchvision/models/__init__.py
View file @
cac4e228
...
...
@@ -14,4 +14,4 @@ from .vgg import *
from
.vision_transformer
import
*
from
.swin_transformer
import
*
from
.
import
detection
,
optical_flow
,
quantization
,
segmentation
,
video
from
._api
import
get_model
,
get_model_weights
,
get_weight
,
list_models
from
._api
import
get_model
,
get_model_builder
,
get_model_weights
,
get_weight
,
list_models
torchvision/models/_api.py
View file @
cac4e228
...
...
@@ -13,7 +13,7 @@ from torchvision._utils import StrEnum
from
.._internally_replaced_utils
import
load_state_dict_from_url
__all__
=
[
"WeightsEnum"
,
"Weights"
,
"get_model"
,
"get_model_weights"
,
"get_weight"
,
"list_models"
]
__all__
=
[
"WeightsEnum"
,
"Weights"
,
"get_model"
,
"get_model_builder"
,
"get_model_weights"
,
"get_weight"
,
"list_models"
]
@
dataclass
...
...
@@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W:
Returns:
weights_enum (W): The weights enum class associated with the model.
"""
model
=
find
_model
(
name
)
if
isinstance
(
name
,
str
)
else
name
model
=
get
_model
_builder
(
name
)
if
isinstance
(
name
,
str
)
else
name
return
cast
(
W
,
_get_enum_from_fn
(
model
))
...
...
@@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return
sorted
(
models
)
def
find_model
(
name
:
str
)
->
Callable
[...,
M
]:
def
get_model_builder
(
name
:
str
)
->
Callable
[...,
M
]:
"""
Gets the model name and returns the model builder method.
.. betastatus:: function
Args:
name (str): The name under which the model is registered.
Returns:
fn (Callable): The model builder method.
"""
name
=
name
.
lower
()
try
:
fn
=
BUILTIN_MODELS
[
name
]
...
...
@@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M:
Returns:
model (nn.Module): The initialized model.
"""
fn
=
find
_model
(
name
)
fn
=
get
_model
_builder
(
name
)
return
fn
(
**
config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment