"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0d99ae1fe84f8d191abe5ed1c2f4fdc5a9f9a773"
Unverified Commit a465fb41 authored by kira's avatar kira Committed by GitHub
Browse files

Add type hints in alexnet.py (#1983)



* Update alexnet.py

* Update alexnet.py

* Update mmcv/cnn/alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update alexnet.py

* Update mmcv/cnn/alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update alexnet.py

* fix importing format

* Update alexnet.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 699398ad
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional
import torch
import torch.nn as nn
......@@ -11,7 +13,7 @@ class AlexNet(nn.Module):
num_classes (int): number of classes for classification.
"""
def __init__(self, num_classes=-1):
def __init__(self, num_classes: int = -1):
super().__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
......@@ -40,7 +42,7 @@ class AlexNet(nn.Module):
nn.Linear(4096, num_classes),
)
def init_weights(self, pretrained=None):
def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str):
logger = logging.getLogger()
from ..runner import load_checkpoint
......@@ -51,7 +53,7 @@ class AlexNet(nn.Module):
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
if self.num_classes > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment