Unverified Commit 1187d363 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

adding alexnet prototype model (#4670)



* adding alexnet prototype model

* adding recipe reference

* fixing lint
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3f6ff20b
from .resnet import * from .resnet import *
from .alexnet import *
from . import detection from . import detection
from . import quantization from . import quantization
import warnings
from functools import partial
from typing import Any, Optional
from ...models.alexnet import AlexNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
class AlexNetWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 56.522,
"acc@5": 79.066,
},
)
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = AlexNetWeights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
model = AlexNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment