Unverified Commit 973db145 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Fixed typing in constructors of models submodules (#2875)

* fix: Fixed constructor typing in models._utils

* fix: Fixed constructor typing in models.alexnet

* fix: Fixed constructor typing in models.mnasnet

* fix: Fixed constructor typing in models.squeezenet
parent d4cd0bed
...@@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict): ...@@ -41,7 +41,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
"return_layers": Dict[str, str], "return_layers": Dict[str, str],
} }
def __init__(self, model: nn.Module, return_layers: Dict[str, str]): def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
if not set(return_layers).issubset([name for name, _ in model.named_children()]): if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model") raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers orig_return_layers = return_layers
......
...@@ -14,7 +14,7 @@ model_urls = { ...@@ -14,7 +14,7 @@ model_urls = {
class AlexNet(nn.Module): class AlexNet(nn.Module):
def __init__(self, num_classes: int = 1000): def __init__(self, num_classes: int = 1000) -> None:
super(AlexNet, self).__init__() super(AlexNet, self).__init__()
self.features = nn.Sequential( self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
......
...@@ -32,7 +32,7 @@ class _InvertedResidual(nn.Module): ...@@ -32,7 +32,7 @@ class _InvertedResidual(nn.Module):
stride: int, stride: int,
expansion_factor: int, expansion_factor: int,
bn_momentum: float = 0.1 bn_momentum: float = 0.1
): ) -> None:
super(_InvertedResidual, self).__init__() super(_InvertedResidual, self).__init__()
assert stride in [1, 2] assert stride in [1, 2]
assert kernel_size in [3, 5] assert kernel_size in [3, 5]
...@@ -109,7 +109,7 @@ class MNASNet(torch.nn.Module): ...@@ -109,7 +109,7 @@ class MNASNet(torch.nn.Module):
alpha: float, alpha: float,
num_classes: int = 1000, num_classes: int = 1000,
dropout: float = 0.2 dropout: float = 0.2
): ) -> None:
super(MNASNet, self).__init__() super(MNASNet, self).__init__()
assert alpha > 0.0 assert alpha > 0.0
self.alpha = alpha self.alpha = alpha
......
...@@ -20,7 +20,7 @@ class Fire(nn.Module): ...@@ -20,7 +20,7 @@ class Fire(nn.Module):
squeeze_planes: int, squeeze_planes: int,
expand1x1_planes: int, expand1x1_planes: int,
expand3x3_planes: int expand3x3_planes: int
): ) -> None:
super(Fire, self).__init__() super(Fire, self).__init__()
self.inplanes = inplanes self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
...@@ -46,7 +46,7 @@ class SqueezeNet(nn.Module): ...@@ -46,7 +46,7 @@ class SqueezeNet(nn.Module):
self, self,
version: str = '1_0', version: str = '1_0',
num_classes: int = 1000 num_classes: int = 1000
): ) -> None:
super(SqueezeNet, self).__init__() super(SqueezeNet, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
if version == '1_0': if version == '1_0':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment