"docs/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "a576c7612d85d24780d26f382a046ab45d2b1bf7"
Unverified Commit a465fb41 authored by kira's avatar kira Committed by GitHub
Browse files

Add type hints in alexnet.py (#1983)



* Update alexnet.py

* Update alexnet.py

* Update mmcv/cnn/alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update alexnet.py

* Update mmcv/cnn/alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update alexnet.py

* fix importing format

* Update alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 699398ad
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging import logging
from typing import Optional
import torch
import torch.nn as nn import torch.nn as nn
...@@ -11,7 +13,7 @@ class AlexNet(nn.Module): ...@@ -11,7 +13,7 @@ class AlexNet(nn.Module):
num_classes (int): number of classes for classification. num_classes (int): number of classes for classification.
""" """
def __init__(self, num_classes=-1): def __init__(self, num_classes: int = -1):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.features = nn.Sequential( self.features = nn.Sequential(
...@@ -40,7 +42,7 @@ class AlexNet(nn.Module): ...@@ -40,7 +42,7 @@ class AlexNet(nn.Module):
nn.Linear(4096, num_classes), nn.Linear(4096, num_classes),
) )
def init_weights(self, pretrained=None): def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str): if isinstance(pretrained, str):
logger = logging.getLogger() logger = logging.getLogger()
from ..runner import load_checkpoint from ..runner import load_checkpoint
...@@ -51,7 +53,7 @@ class AlexNet(nn.Module): ...@@ -51,7 +53,7 @@ class AlexNet(nn.Module):
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x) x = self.features(x)
if self.num_classes > 0: if self.num_classes > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment