Unverified Commit 98146a1a authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

style: Added annotation typing for alexnet (#2859)

parent 3756b607
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Any
__all__ = ['AlexNet', 'alexnet']
......@@ -13,7 +14,7 @@ model_urls = {
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
def __init__(self, num_classes: int = 1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
......@@ -41,7 +42,7 @@ class AlexNet(nn.Module):
nn.Linear(4096, num_classes),
)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
......@@ -49,7 +50,7 @@ class AlexNet(nn.Module):
return x
def alexnet(pretrained=False, progress=True, **kwargs):
def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet:
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment