Unverified Commit 2d927283 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix mypy errors after the 0.981 release (#6652)

parent 55a436cb
......@@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name)
W = TypeVar("W", bound=WeightsEnum)
def get_model_weights(name: Union[Callable, str]) -> W:
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
"""
Retuns the weights enum class associated to the given model.
......@@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W:
name (callable or str): The model builder function or the name under which it is registered.
Returns:
weights_enum (W): The weights enum class associated with the model.
weights_enum (WeightsEnum): The weights enum class associated with the model.
"""
model = get_model_builder(name) if isinstance(name, str) else name
return cast(W, _get_enum_from_fn(model))
return _get_enum_from_fn(model)
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
......@@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return sorted(models)
def get_model_builder(name: str) -> Callable[..., M]:
def get_model_builder(name: str) -> Callable[..., nn.Module]:
"""
Gets the model name and returns the model builder method.
......@@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]:
return fn
def get_model(name: str, **config: Any) -> M:
def get_model(name: str, **config: Any) -> nn.Module:
"""
Gets the model name and configuration and returns an instantiated model.
......
......@@ -2,7 +2,7 @@ import csv
import functools
import pathlib
import pickle
from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
import torch
import torch.distributed as dist
......@@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
return obj
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
return cast(D, _getattr_closure(path, attrs=name.split(".")))
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
return _getattr_closure(path, attrs=name.split("."))
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment