Unverified Commit 2d927283 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix mypy errors after the 0.981 release (#6652)

parent 55a436cb
...@@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum: ...@@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name) return weights_enum.from_str(value_name)
W = TypeVar("W", bound=WeightsEnum) def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
def get_model_weights(name: Union[Callable, str]) -> W:
""" """
Retuns the weights enum class associated to the given model. Retuns the weights enum class associated to the given model.
...@@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W: ...@@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W:
name (callable or str): The model builder function or the name under which it is registered. name (callable or str): The model builder function or the name under which it is registered.
Returns: Returns:
weights_enum (W): The weights enum class associated with the model. weights_enum (WeightsEnum): The weights enum class associated with the model.
""" """
model = get_model_builder(name) if isinstance(name, str) else name model = get_model_builder(name) if isinstance(name, str) else name
return cast(W, _get_enum_from_fn(model)) return _get_enum_from_fn(model)
def _get_enum_from_fn(fn: Callable) -> WeightsEnum: def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
...@@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: ...@@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return sorted(models) return sorted(models)
def get_model_builder(name: str) -> Callable[..., M]: def get_model_builder(name: str) -> Callable[..., nn.Module]:
""" """
Gets the model name and returns the model builder method. Gets the model name and returns the model builder method.
...@@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]: ...@@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]:
return fn return fn
def get_model(name: str, **config: Any) -> M: def get_model(name: str, **config: Any) -> nn.Module:
""" """
Gets the model name and configuration and returns an instantiated model. Gets the model name and configuration and returns an instantiated model.
......
...@@ -2,7 +2,7 @@ import csv ...@@ -2,7 +2,7 @@ import csv
import functools import functools
import pathlib import pathlib
import pickle import pickle
from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any: ...@@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
return obj return obj
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D: def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
return cast(D, _getattr_closure(path, attrs=name.split("."))) return _getattr_closure(path, attrs=name.split("."))
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D: def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment