"...text-generation-inference.git" did not exist on "55bd4fed7da83a566dca08b0bb29dbc5929a90eb"
Unverified Commit 5a2906cb authored by nxznm's avatar nxznm Committed by GitHub
Browse files

Add type hints in vgg.py (#2050)

* add type hints in vgg.py

* Update mmcv/cnn/vgg.py

* add type hints for return value in vgg.py
parent 15495ea0
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging import logging
from typing import List, Optional, Sequence, Tuple, Union
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from .utils import constant_init, kaiming_init, normal_init from .utils import constant_init, kaiming_init, normal_init
def conv3x3(in_planes, out_planes, dilation=1): def conv3x3(in_planes: int, out_planes: int, dilation: int = 1) -> nn.Module:
"""3x3 convolution with padding.""" """3x3 convolution with padding."""
return nn.Conv2d( return nn.Conv2d(
in_planes, in_planes,
...@@ -16,12 +18,12 @@ def conv3x3(in_planes, out_planes, dilation=1): ...@@ -16,12 +18,12 @@ def conv3x3(in_planes, out_planes, dilation=1):
dilation=dilation) dilation=dilation)
def make_vgg_layer(inplanes, def make_vgg_layer(inplanes: int,
planes, planes: int,
num_blocks, num_blocks: int,
dilation=1, dilation: int = 1,
with_bn=False, with_bn: bool = False,
ceil_mode=False): ceil_mode: bool = False) -> List[nn.Module]:
layers = [] layers = []
for _ in range(num_blocks): for _ in range(num_blocks):
layers.append(conv3x3(inplanes, planes, dilation)) layers.append(conv3x3(inplanes, planes, dilation))
...@@ -59,17 +61,17 @@ class VGG(nn.Module): ...@@ -59,17 +61,17 @@ class VGG(nn.Module):
} }
def __init__(self, def __init__(self,
depth, depth: int,
with_bn=False, with_bn: bool = False,
num_classes=-1, num_classes: int = -1,
num_stages=5, num_stages: int = 5,
dilations=(1, 1, 1, 1, 1), dilations: Sequence[int] = (1, 1, 1, 1, 1),
out_indices=(0, 1, 2, 3, 4), out_indices: Sequence[int] = (0, 1, 2, 3, 4),
frozen_stages=-1, frozen_stages: int = -1,
bn_eval=True, bn_eval: bool = True,
bn_frozen=False, bn_frozen: bool = False,
ceil_mode=False, ceil_mode: bool = False,
with_last_pool=True): with_last_pool: bool = True):
super().__init__() super().__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for vgg') raise KeyError(f'invalid depth {depth} for vgg')
...@@ -122,7 +124,7 @@ class VGG(nn.Module): ...@@ -122,7 +124,7 @@ class VGG(nn.Module):
nn.Linear(4096, num_classes), nn.Linear(4096, num_classes),
) )
def init_weights(self, pretrained=None): def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str): if isinstance(pretrained, str):
logger = logging.getLogger() logger = logging.getLogger()
from ..runner import load_checkpoint from ..runner import load_checkpoint
...@@ -138,7 +140,7 @@ class VGG(nn.Module): ...@@ -138,7 +140,7 @@ class VGG(nn.Module):
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
outs = [] outs = []
vgg_layers = getattr(self, self.module_name) vgg_layers = getattr(self, self.module_name)
for i in range(len(self.stage_blocks)): for i in range(len(self.stage_blocks)):
...@@ -156,7 +158,7 @@ class VGG(nn.Module): ...@@ -156,7 +158,7 @@ class VGG(nn.Module):
else: else:
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode: bool = True) -> None:
super().train(mode) super().train(mode)
if self.bn_eval: if self.bn_eval:
for m in self.modules(): for m in self.modules():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment