Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2d927283
Unverified
Commit
2d927283
authored
Sep 28, 2022
by
Philip Meier
Committed by
GitHub
Sep 28, 2022
Browse files
fix mypy errors after the 0.981 release (#6652)
parent
55a436cb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
11 deletions
+8
-11
torchvision/models/_api.py
torchvision/models/_api.py
+5
-8
torchvision/prototype/datasets/utils/_internal.py
torchvision/prototype/datasets/utils/_internal.py
+3
-3
No files found.
torchvision/models/_api.py
View file @
2d927283
...
...
@@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum:
return
weights_enum
.
from_str
(
value_name
)
W
=
TypeVar
(
"W"
,
bound
=
WeightsEnum
)
def
get_model_weights
(
name
:
Union
[
Callable
,
str
])
->
W
:
def
get_model_weights
(
name
:
Union
[
Callable
,
str
])
->
WeightsEnum
:
"""
Retuns the weights enum class associated to the given model.
...
...
@@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W:
name (callable or str): The model builder function or the name under which it is registered.
Returns:
weights_enum (W): The weights enum class associated with the model.
weights_enum (W
eightsEnum
): The weights enum class associated with the model.
"""
model
=
get_model_builder
(
name
)
if
isinstance
(
name
,
str
)
else
name
return
cast
(
W
,
_get_enum_from_fn
(
model
)
)
return
_get_enum_from_fn
(
model
)
def
_get_enum_from_fn
(
fn
:
Callable
)
->
WeightsEnum
:
...
...
@@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return
sorted
(
models
)
def
get_model_builder
(
name
:
str
)
->
Callable
[...,
M
]:
def
get_model_builder
(
name
:
str
)
->
Callable
[...,
nn
.
Module
]:
"""
Gets the model name and returns the model builder method.
...
...
@@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]:
return
fn
def
get_model
(
name
:
str
,
**
config
:
Any
)
->
M
:
def
get_model
(
name
:
str
,
**
config
:
Any
)
->
nn
.
Module
:
"""
Gets the model name and configuration and returns an instantiated model.
...
...
torchvision/prototype/datasets/utils/_internal.py
View file @
2d927283
...
...
@@ -2,7 +2,7 @@ import csv
import
functools
import
pathlib
import
pickle
from
typing
import
Any
,
BinaryIO
,
Callable
,
cast
,
Dict
,
IO
,
Iterator
,
List
,
Sequence
,
Sized
,
Tuple
,
TypeVar
,
Union
from
typing
import
Any
,
BinaryIO
,
Callable
,
Dict
,
IO
,
Iterator
,
List
,
Sequence
,
Sized
,
Tuple
,
TypeVar
,
Union
import
torch
import
torch.distributed
as
dist
...
...
@@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
return
obj
def
_path_attribute_accessor
(
path
:
pathlib
.
Path
,
*
,
name
:
str
)
->
D
:
return
cast
(
D
,
_getattr_closure
(
path
,
attrs
=
name
.
split
(
"."
))
)
def
_path_attribute_accessor
(
path
:
pathlib
.
Path
,
*
,
name
:
str
)
->
Any
:
return
_getattr_closure
(
path
,
attrs
=
name
.
split
(
"."
))
def
_path_accessor_closure
(
data
:
Tuple
[
str
,
Any
],
*
,
getter
:
Callable
[[
pathlib
.
Path
],
D
])
->
D
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment