Unverified Commit 11bd2eaa authored by Vasilis Vryniotis, committed by GitHub

Port Multi-weight support from prototype to main (#5618)



* Moving basefiles outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet.

* Porting googlenet

* Porting inception

* Porting mnasnet

* Porting mobilenetv2

* Porting mobilenetv3

* Porting regnet

* Porting resnet

* Porting shufflenetv2

* Porting squeezenet

* Porting vgg

* Porting vit

* Fix docstrings

* Fixing imports

* Adding missing import

* Fix mobilenet imports

* Fix tests

* Fix prototype tests

* Exclude get_weight from models on test

* Fix init files

* Porting googlenet

* Porting inception

* porting mobilenetv2

* porting mobilenetv3

* porting resnet

* porting shufflenetv2

* Fix test and linter

* Fixing docs.

* Porting Detection models (#5617)

* fix inits

* fix docs

* Port faster_rcnn

* Port fcos

* Port keypoint_rcnn

* Port mask_rcnn

* Port retinanet

* Port ssd

* Port ssdlite

* Fix linter

* Fixing tests

* Fixing tests

* Fixing vgg test

* Porting Optical Flow, Segmentation, Video models (#5619)

* Porting raft

* Porting video resnet

* Porting deeplabv3

* Porting fcn and lraspp

* Fixing the tests and linter

* Porting docs, examples, tutorials and galleries (#5620)

* Fix examples, tutorials and gallery

* Update gallery/plot_optical_flow.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Fix import

* Revert hardcoded normalization

* fix uncommitted changes

* Fix bug

* Fix more bugs

* Making resize optional for segmentation

* Fixing preset

* Fix mypy

* Fixing documentation strings

* Fix flake8

* minor refactoring
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Resolve conflict

* Porting model tests (#5622)

* Porting tests

* Remove unnecessary variable

* Fix linter

* Move prototype to extended tests

* Fix download models job

* Update CI on Multiweight branch to use the new weight download approach (#5628)

* port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal

* Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624)
Co-authored-by: Anton Thomma <anton@pri.co.nz>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* pre-download model weights in CI docs build (#5625)

* pre-download model weights in CI docs build

* move changes into template

* change docs image

* Regenerated config.yml
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>

* Porting reference scripts and updating presets (#5629)

* Making _preset.py classes

* Remove support of targets on presets.

* Rewriting the video preset

* Adding tests to check that the bundled transforms are JIT scriptable

* Rename all presets from *Eval to *Inference

* Minor refactoring

* Remove --prototype and --pretrained from reference scripts

* remove pretrained_backbone refs

* Corrections and simplifications

* Fixing bug

* Fixing linter

* Fix flake8

* restore documentation example

* minor fixes

* fix optical flow missing param

* Fixing commands

* Adding weights_backbone support in detection and segmentation

* Updating the commands for InceptionV3

* Setting `weights_backbone` to its fully BC value (#5653)

* Replace default `weights_backbone=None` with its BC values.

* Fixing tests

* Fix linter

* Update docs.

* Update preprocessing on reference scripts.

* Change qat/ptq to their full values.

* Refactoring preprocessing

* Fix video preset

* No initialization on VGG if pretrained

* Fix warning messages for backbone utils.

* Adding star to all preset constructors.

* Fix mypy.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>
parent 375e4ab2
from . import datasets
from . import features
from . import models
from . import transforms
from . import utils
from .alexnet import *
from .convnext import *
from .densenet import *
from .efficientnet import *
from .googlenet import *
from .inception import *
from .mnasnet import *
from .mobilenet import *
from .regnet import *
from .resnet import *
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from . import detection
from . import optical_flow
from . import quantization
from . import segmentation
from . import video
from ._api import get_weight
import functools
import warnings
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union
from torch import nn
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw
from ._api import WeightsEnum
W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
"""Decorates a model builder with the new interface to make it compatible with the old.
In particular this handles two things:
1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
:func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
Args:
**weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
should be accessed with :meth:`~dict.get`.
"""
def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
@kwonly_to_pos_or_kw
@functools.wraps(builder)
def inner_wrapper(*args: Any, **kwargs: Any) -> M:
for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
# If neither the weights nor the pretrained parameter was passed, or the weights argument already uses
# the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
# weight argument, since it is a valid value.
sentinel = object()
weights_arg = kwargs.get(weights_param, sentinel)
if (
(weights_param not in kwargs and pretrained_param not in kwargs)
or isinstance(weights_arg, WeightsEnum)
or (isinstance(weights_arg, str) and weights_arg != "legacy")
or weights_arg is None
):
continue
# If the pretrained parameter was passed as a positional argument, it is now mapped to
# `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
# signature to infer the names of positionally passed arguments and thus has no knowledge that there
# used to be a pretrained parameter.
pretrained_positional = weights_arg is not sentinel
if pretrained_positional:
# We put the pretrained argument under its legacy name in the keyword argument dictionary to have
# unified access to the value if the default value is a callable.
kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
else:
pretrained_arg = kwargs[pretrained_param]
if pretrained_arg:
default_weights_arg = default(kwargs) if callable(default) else default
if not isinstance(default_weights_arg, WeightsEnum):
raise ValueError(f"No weights available for model {builder.__name__}")
else:
default_weights_arg = None
if not pretrained_positional:
warnings.warn(
f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
)
msg = (
f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)
del kwargs[pretrained_param]
kwargs[weights_param] = default_weights_arg
return builder(*args, **kwargs)
return inner_wrapper
return outer_wrapper
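# Usage sketch (illustrative, not part of this diff): a builder decorated with
# @handle_legacy_interface accepts both calling conventions. Taking the alexnet
# builder defined further below as an example:
#
#   alexnet(pretrained=True)   # deprecation warning; resolves to the registered
#                              # default, i.e. weights=AlexNet_Weights.IMAGENET1K_V1
#   alexnet(True)              # positional use is remapped (and warned about) too
#   alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)  # new, preferred spelling
#   alexnet(weights=None)      # new-style "no weights"; no warning is emitted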
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
if param in kwargs:
if kwargs[param] != new_value:
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
else:
kwargs[param] = new_value
def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
if param is not None:
if param != new_value:
raise ValueError(f"The parameter expected value {new_value} but got {param} instead.")
return new_value
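# Behaviour sketch for the two overwrite helpers above (hypothetical values):
#
#   kwargs = {"num_classes": 1000}
#   _ovewrite_named_param(kwargs, "num_classes", 1000)  # no-op, values agree
#   _ovewrite_named_param(kwargs, "num_classes", 10)    # raises ValueError
#   _ovewrite_named_param({}, "num_classes", 1000)      # sets the missing key
#
#   _ovewrite_value_param(None, 91)  # returns 91 (user did not pass a value)
#   _ovewrite_value_param(91, 91)    # returns 91 (consistent with the weights)
#   _ovewrite_value_param(10, 91)    # raises ValueError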
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class AlexNet_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "AlexNet",
"publication_year": 2012,
"num_params": 61100840,
"size": (224, 224),
"min_size": (63, 63),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 56.522,
"acc@5": 79.066,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = AlexNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
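# End-to-end sketch of the multi-weight API (`img` is a hypothetical PIL image
# or tensor supplied by the caller):
#
#   weights = AlexNet_Weights.IMAGENET1K_V1           # or AlexNet_Weights.DEFAULT
#   model = alexnet(weights=weights)
#   model.eval()
#   preprocess = weights.transforms()                 # bundled inference preset
#   batch = preprocess(img).unsqueeze(0)
#   class_id = model(batch).squeeze(0).argmax().item()
#   category = weights.meta["categories"][class_id]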
from functools import partial
from typing import Any, List, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.convnext import ConvNeXt, CNBlockConfig
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"ConvNeXt",
"ConvNeXt_Tiny_Weights",
"ConvNeXt_Small_Weights",
"ConvNeXt_Base_Weights",
"ConvNeXt_Large_Weights",
"convnext_tiny",
"convnext_small",
"convnext_base",
"convnext_large",
]
def _convnext(
block_setting: List[CNBlockConfig],
stochastic_depth_prob: float,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> ConvNeXt:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "image_classification",
"architecture": "ConvNeXt",
"publication_year": 2022,
"size": (224, 224),
"min_size": (32, 32),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
}
class ConvNeXt_Tiny_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236),
meta={
**_COMMON_META,
"num_params": 28589128,
"acc@1": 82.520,
"acc@5": 96.146,
},
)
DEFAULT = IMAGENET1K_V1
class ConvNeXt_Small_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230),
meta={
**_COMMON_META,
"num_params": 50223688,
"acc@1": 83.616,
"acc@5": 96.650,
},
)
DEFAULT = IMAGENET1K_V1
class ConvNeXt_Base_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 88591464,
"acc@1": 84.062,
"acc@5": 96.870,
},
)
DEFAULT = IMAGENET1K_V1
class ConvNeXt_Large_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 197767336,
"acc@1": 84.414,
"acc@5": 96.976,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
weights = ConvNeXt_Tiny_Weights.verify(weights)
block_setting = [
CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 9),
CNBlockConfig(768, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
*, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
weights = ConvNeXt_Small_Weights.verify(weights)
block_setting = [
CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 27),
CNBlockConfig(768, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
weights = ConvNeXt_Base_Weights.verify(weights)
block_setting = [
CNBlockConfig(128, 256, 3),
CNBlockConfig(256, 512, 3),
CNBlockConfig(512, 1024, 27),
CNBlockConfig(1024, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
*, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
weights = ConvNeXt_Large_Weights.verify(weights)
block_setting = [
CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 3),
CNBlockConfig(768, 1536, 27),
CNBlockConfig(1536, None, 3),
]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
import re
from functools import partial
from typing import Any, Optional, Tuple
import torch.nn as nn
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"DenseNet",
"DenseNet121_Weights",
"DenseNet161_Weights",
"DenseNet169_Weights",
"DenseNet201_Weights",
"densenet121",
"densenet161",
"densenet169",
"densenet201",
]
def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
state_dict = weights.get_state_dict(progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
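# Example of the remapping performed above (old checkpoint key -> new key):
#
#   "features.denseblock1.denselayer1.norm.1.weight"
#       -> "features.denseblock1.denselayer1.norm1.weight"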
def _densenet(
growth_rate: int,
block_config: Tuple[int, int, int, int],
num_init_features: int,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> DenseNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if weights is not None:
_load_state_dict(model=model, weights=weights, progress=progress)
return model
_COMMON_META = {
"task": "image_classification",
"architecture": "DenseNet",
"publication_year": 2016,
"size": (224, 224),
"min_size": (29, 29),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/pull/116",
}
class DenseNet121_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 7978856,
"acc@1": 74.434,
"acc@5": 91.972,
},
)
DEFAULT = IMAGENET1K_V1
class DenseNet161_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 28681000,
"acc@1": 77.138,
"acc@5": 93.560,
},
)
DEFAULT = IMAGENET1K_V1
class DenseNet169_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 14149480,
"acc@1": 75.600,
"acc@5": 92.806,
},
)
DEFAULT = IMAGENET1K_V1
class DenseNet201_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 20013928,
"acc@1": 76.896,
"acc@5": 93.370,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
from .faster_rcnn import *
from .fcos import *
from .keypoint_rcnn import *
from .mask_rcnn import *
from .retinanet import *
from .ssd import *
from .ssdlite import *
from typing import Any, Optional, Union
from torch import nn
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.faster_rcnn import (
_mobilenet_extractor,
_resnet_fpn_extractor,
_validate_trainable_layers,
AnchorGenerator,
FasterRCNN,
misc_nn_ops,
overwrite_eps,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"FasterRCNN",
"FasterRCNN_ResNet50_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
]
_COMMON_META = {
"task": "image_object_detection",
"architecture": "FasterRCNN",
"publication_year": 2015,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
"num_params": 41755286,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
)
DEFAULT = COCO_V1
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
"num_params": 19386354,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
)
DEFAULT = COCO_V1
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
"num_params": 19386354,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
*,
weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
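# Sketch of the two initialisation modes exposed by the new interface: `weights`
# loads the fully-trained COCO detector, while `weights_backbone` initialises
# only the backbone (useful when fine-tuning on a custom dataset):
#
#   model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
#   model = fasterrcnn_resnet50_fpn(
#       weights=None,
#       weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
#       num_classes=5,  # hypothetical dataset: 4 classes + background
#   )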
def _fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
progress: bool,
num_classes: Optional[int],
weights_backbone: Optional[MobileNet_V3_Large_Weights],
trainable_backbone_layers: Optional[int],
**kwargs: Any,
) -> FasterRCNN:
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = (
(
32,
64,
128,
256,
512,
),
) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
model = FasterRCNN(
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = {
"rpn_score_thresh": 0.05,
}
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights=weights,
progress=progress,
num_classes=num_classes,
weights_backbone=weights_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
*,
weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
defaults = {
"min_size": 320,
"max_size": 640,
"rpn_pre_nms_top_n_test": 150,
"rpn_post_nms_top_n_test": 150,
"rpn_score_thresh": 0.05,
}
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights=weights,
progress=progress,
num_classes=num_classes,
weights_backbone=weights_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.fcos import (
_resnet_fpn_extractor,
_validate_trainable_layers,
FCOS,
LastLevelP6P7,
misc_nn_ops,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"FCOS",
"FCOS_ResNet50_FPN_Weights",
"fcos_resnet50_fpn",
]
class FCOS_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
transforms=ObjectDetectionEval,
meta={
"task": "image_object_detection",
"architecture": "FCOS",
"publication_year": 2019,
"num_params": 32269600,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
"map": 39.2,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcos_resnet50_fpn(
*,
weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FCOS:
weights = FCOS_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
model = FCOS(backbone, num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.keypoint_rcnn import (
_resnet_fpn_extractor,
_validate_trainable_layers,
KeypointRCNN,
misc_nn_ops,
overwrite_eps,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"KeypointRCNN",
"KeypointRCNN_ResNet50_FPN_Weights",
"keypointrcnn_resnet50_fpn",
]
_COMMON_META = {
"task": "image_object_detection",
"architecture": "KeypointRCNN",
"publication_year": 2017,
"categories": _COCO_PERSON_CATEGORIES,
"keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
"interpolation": InterpolationMode.BILINEAR,
}
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
COCO_LEGACY = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
"num_params": 59137258,
"recipe": "https://github.com/pytorch/vision/issues/1606",
"map": 50.6,
"map_kp": 61.1,
},
)
COCO_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=ObjectDetectionEval,
meta={
**_COMMON_META,
"num_params": 59137258,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"map": 54.6,
"map_kp": 65.0,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
if kwargs["pretrained"] == "legacy"
else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
*,
weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
num_keypoints: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> KeypointRCNN:
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"]))
else:
if num_classes is None:
num_classes = 2
if num_keypoints is None:
num_keypoints = 17
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
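# Sketch of the legacy-string escape hatch wired into the decorator above: the
# old `pretrained="legacy"` spelling resolves to COCO_LEGACY, while any other
# truthy legacy value resolves to COCO_V1.
#
#   keypointrcnn_resnet50_fpn(pretrained="legacy")  # -> COCO_LEGACY (deprecated)
#   keypointrcnn_resnet50_fpn(pretrained=True)      # -> COCO_V1 (deprecated)
#   keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY)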
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.mask_rcnn import (
_resnet_fpn_extractor,
_validate_trainable_layers,
MaskRCNN,
misc_nn_ops,
overwrite_eps,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"MaskRCNN",
"MaskRCNN_ResNet50_FPN_Weights",
"maskrcnn_resnet50_fpn",
]
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=ObjectDetectionEval,
meta={
"task": "image_object_detection",
"architecture": "MaskRCNN",
"publication_year": 2017,
"num_params": 44401393,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
"map": 37.9,
"map_mask": 34.6,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn(
*,
weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> MaskRCNN:
weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.retinanet import (
_resnet_fpn_extractor,
_validate_trainable_layers,
RetinaNet,
LastLevelP6P7,
misc_nn_ops,
overwrite_eps,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
__all__ = [
"RetinaNet",
"RetinaNet_ResNet50_FPN_Weights",
"retinanet_resnet50_fpn",
]
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
transforms=ObjectDetectionEval,
meta={
"task": "image_object_detection",
"architecture": "RetinaNet",
"publication_year": 2017,
"num_params": 34014999,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
*,
weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> RetinaNet:
weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
# skip P2 because it generates too many anchors (according to their paper)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
model = RetinaNet(backbone, num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
return model
import warnings
from typing import Any, Optional
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.ssd import (
_validate_trainable_layers,
_vgg_extractor,
DefaultBoxGenerator,
SSD,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..vgg import VGG16_Weights, vgg16
__all__ = [
"SSD300_VGG16_Weights",
"ssd300_vgg16",
]
class SSD300_VGG16_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
transforms=ObjectDetectionEval,
meta={
"task": "image_object_detection",
"architecture": "SSD",
"publication_year": 2015,
"num_params": 35641826,
"size": (300, 300),
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
)
def ssd300_vgg16(
*,
weights: Optional[SSD300_VGG16_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[VGG16_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> SSD:
weights = SSD300_VGG16_Weights.verify(weights)
weights_backbone = VGG16_Weights.verify(weights_backbone)
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
)
# Use custom backbones more appropriate for SSD
backbone = vgg16(weights=weights_backbone, progress=progress)
backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator(
[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
steps=[8, 16, 32, 64, 100, 300],
)
defaults = {
# Rescale the input in a way compatible with the backbone
"image_mean": [0.48235, 0.45882, 0.40784],
"image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of ToTensor
}
kwargs: Any = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
import warnings
from functools import partial
from typing import Any, Callable, Optional
from torch import nn
from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.ssdlite import (
_mobilenet_extractor,
_normal_init,
_validate_trainable_layers,
DefaultBoxGenerator,
det_utils,
SSD,
SSDLiteHead,
)
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
__all__ = [
"SSDLite320_MobileNet_V3_Large_Weights",
"ssdlite320_mobilenet_v3_large",
]
class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
transforms=ObjectDetectionEval,
meta={
"task": "image_object_detection",
"architecture": "SSDLite",
"publication_year": 2018,
"num_params": 3440060,
"size": (320, 320),
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3,
},
)
DEFAULT = COCO_V1
@handle_legacy_interface(
weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def ssdlite320_mobilenet_v3_large(
*,
weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> SSD:
weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if "size" in kwargs:
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
)
# Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
reduce_tail = weights_backbone is None
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
backbone = mobilenet_v3_large(
weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
)
if weights_backbone is None:
# Change the default initialization scheme if not pretrained
_normal_init(backbone)
backbone = _mobilenet_extractor(
backbone,
trainable_backbone_layers,
norm_layer,
)
size = (320, 320)
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
out_channels = det_utils.retrieve_out_channels(backbone, size)
num_anchors = anchor_generator.num_anchors_per_location()
if len(out_channels) != len(anchor_generator.aspect_ratios):
raise ValueError(
f"The length of the output channels from the backbone ({len(out_channels)}) does not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})"
)
defaults = {
"score_thresh": 0.001,
"nms_thresh": 0.55,
"detections_per_img": 300,
"topk_candidates": 300,
# Rescale the input in a way compatible with the backbone:
# The following mean/std rescale the data from [0, 1] to [-1, 1]
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5],
}
kwargs: Any = {**defaults, **kwargs}
model = SSD(
backbone,
anchor_generator,
size,
num_classes,
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
**kwargs,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
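# Sketch: the reduced-tail switch above depends only on whether backbone weights
# were requested, so these two calls build backbones of different width:
#
#   ssdlite320_mobilenet_v3_large(weights_backbone=MobileNet_V3_Large_Weights.IMAGENET1K_V1)
#   ssdlite320_mobilenet_v3_large(weights_backbone=None)  # reduced_tail=True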
from functools import partial
from typing import Any, Optional, Sequence, Union
from torch import nn
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"EfficientNet",
"EfficientNet_B0_Weights",
"EfficientNet_B1_Weights",
"EfficientNet_B2_Weights",
"EfficientNet_B3_Weights",
"EfficientNet_B4_Weights",
"EfficientNet_B5_Weights",
"EfficientNet_B6_Weights",
"EfficientNet_B7_Weights",
"EfficientNet_V2_S_Weights",
"EfficientNet_V2_M_Weights",
"EfficientNet_V2_L_Weights",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
"efficientnet_b3",
"efficientnet_b4",
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_s",
"efficientnet_v2_m",
"efficientnet_v2_l",
]
def _efficientnet(
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
dropout: float,
last_channel: Optional[int],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> EfficientNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "image_classification",
"categories": _IMAGENET_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
}
_COMMON_META_V1 = {
**_COMMON_META,
"architecture": "EfficientNet",
"publication_year": 2019,
"interpolation": InterpolationMode.BICUBIC,
"min_size": (1, 1),
}
_COMMON_META_V2 = {
**_COMMON_META,
"architecture": "EfficientNetV2",
"publication_year": 2021,
"interpolation": InterpolationMode.BILINEAR,
"min_size": (33, 33),
}
class EfficientNet_B0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(
ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 5288548,
"size": (224, 224),
"acc@1": 77.692,
"acc@5": 93.532,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_B1_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(
ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 7794184,
"size": (240, 240),
"acc@1": 78.642,
"acc@5": 94.186,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
transforms=partial(
ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
),
meta={
**_COMMON_META_V1,
"num_params": 7794184,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
"interpolation": InterpolationMode.BILINEAR,
"size": (240, 240),
"acc@1": 79.838,
"acc@5": 94.934,
},
)
DEFAULT = IMAGENET1K_V2
class EfficientNet_B2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(
ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 9109994,
"size": (288, 288),
"acc@1": 80.608,
"acc@5": 95.310,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_B3_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(
ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 12233232,
"size": (300, 300),
"acc@1": 82.008,
"acc@5": 96.054,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_B4_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(
ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 19341616,
"size": (380, 380),
"acc@1": 83.384,
"acc@5": 96.594,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_B5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(
ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 30389784,
"size": (456, 456),
"acc@1": 83.444,
"acc@5": 96.628,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_B6_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(
ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 43040704,
"size": (528, 528),
"acc@1": 84.008,
"acc@5": 96.916,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_B7_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(
ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META_V1,
"num_params": 66347960,
"size": (600, 600),
"acc@1": 84.122,
"acc@5": 96.908,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
transforms=partial(
ImageClassificationEval,
crop_size=384,
resize_size=384,
interpolation=InterpolationMode.BILINEAR,
),
meta={
**_COMMON_META_V2,
"num_params": 21458488,
"size": (384, 384),
"acc@1": 84.228,
"acc@5": 96.878,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_M_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
transforms=partial(
ImageClassificationEval,
crop_size=480,
resize_size=480,
interpolation=InterpolationMode.BILINEAR,
),
meta={
**_COMMON_META_V2,
"num_params": 54139356,
"size": (480, 480),
"acc@1": 85.112,
"acc@5": 97.156,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_L_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
transforms=partial(
ImageClassificationEval,
crop_size=480,
resize_size=480,
interpolation=InterpolationMode.BICUBIC,
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
meta={
**_COMMON_META_V2,
"num_params": 118515272,
"size": (480, 480),
"acc@1": 85.808,
"acc@5": 97.788,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B0_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
*, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B1_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
*, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B2_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
*, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B3_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
*, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B4_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
*, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B5_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
return _efficientnet(
inverted_residual_setting,
0.4,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
*, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B6_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
return _efficientnet(
inverted_residual_setting,
0.5,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
*, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B7_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
return _efficientnet(
inverted_residual_setting,
0.5,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
*, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_V2_S_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
return _efficientnet(
inverted_residual_setting,
0.2,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
*, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_V2_M_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
return _efficientnet(
inverted_residual_setting,
0.3,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
*, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_V2_L_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
return _efficientnet(
inverted_residual_setting,
0.4,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
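
# Usage sketch for the EfficientNet builders above, using the new multi-weight
# API. The import path assumes a torchvision build that includes this port, and
# fetching the checkpoint requires network access.
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1
model = efficientnet_v2_s(weights=weights)  # loads the checkpoint listed in the enum
model.eval()
preprocess = weights.transforms()  # the bundled inference preset (resize, crop, normalize)
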
import warnings
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
class GoogLeNet_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "GoogLeNet",
"publication_year": 2014,
"num_params": 6624904,
"size": (224, 224),
"min_size": (15, 15),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
"acc@1": 69.778,
"acc@5": 89.530,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
weights = GoogLeNet_Weights.verify(weights)
    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        # The checkpoint was trained with transform_input=True and with the
        # auxiliary heads present, so both are forced on so the state_dict
        # keys match; num_classes must match the checkpoint's categories.
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = GoogLeNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
        if not original_aux_logits:
            # The caller did not request aux heads, so drop them after loading.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        else:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
return model
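
# Usage sketch for the builder above. With weights and the default
# aux_logits=False, the aux heads are only used to load the checkpoint and are
# then stripped; passing aux_logits=True keeps them, but they are not pretrained.
# Assumes a torchvision build that includes this port.
from torchvision.models import googlenet, GoogLeNet_Weights

model = googlenet(weights=GoogLeNet_Weights.DEFAULT)  # aux1/aux2 stripped after loading
legacy = googlenet(pretrained=True)  # legacy kwarg, remapped by handle_legacy_interface with a deprecation warning
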
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
class Inception_V3_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342),
meta={
"task": "image_classification",
"architecture": "InceptionV3",
"publication_year": 2015,
"num_params": 27161264,
"size": (299, 299),
"min_size": (75, 75),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
"acc@1": 77.294,
"acc@5": 93.450,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
weights = Inception_V3_Weights.verify(weights)
    original_aux_logits = kwargs.get("aux_logits", True)
    if weights is not None:
        # As with googlenet, the checkpoint expects transform_input=True and was
        # saved with the auxiliary classifier, so both are forced on for loading.
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = Inception3(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
        if not original_aux_logits:
            # The caller did not request the aux classifier, so drop it after loading.
            model.aux_logits = False
            model.AuxLogits = None
return model
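
# Usage sketch for the builder above; unlike googlenet, aux_logits defaults to
# True here, so the pretrained aux classifier is kept. Assumes a torchvision
# build that includes this port.
from torchvision.models import inception_v3, Inception_V3_Weights

weights = Inception_V3_Weights.IMAGENET1K_V1
model = inception_v3(weights=weights)  # expects 299x299 inputs, per weights.meta["size"]
preprocess = weights.transforms()  # ImageClassificationEval(crop_size=299, resize_size=342)
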
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"MNASNet",
"MNASNet0_5_Weights",
"MNASNet0_75_Weights",
"MNASNet1_0_Weights",
"MNASNet1_3_Weights",
"mnasnet0_5",
"mnasnet0_75",
"mnasnet1_0",
"mnasnet1_3",
]
_COMMON_META = {
"task": "image_classification",
"architecture": "MNASNet",
"publication_year": 2018,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/1e100/mnasnet_trainer",
}
class MNASNet0_5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2218512,
"acc@1": 67.734,
"acc@5": 87.490,
},
)
DEFAULT = IMAGENET1K_V1
class MNASNet0_75_Weights(WeightsEnum):
    # If a default model is added here, the corresponding changes need to be done in mnasnet0_75
pass
class MNASNet1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 4383312,
"acc@1": 73.456,
"acc@5": 91.510,
},
)
DEFAULT = IMAGENET1K_V1
class MNASNet1_3_Weights(WeightsEnum):
    # If a default model is added here, the corresponding changes need to be done in mnasnet1_3
pass
def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    if weights is not None:
        # Align num_classes with the checkpoint's category count before building.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = MNASNet(alpha, **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
    return model
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet0_5_Weights.verify(weights)
return _mnasnet(0.5, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet0_75_Weights.verify(weights)
return _mnasnet(0.75, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet1_0_Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
weights = MNASNet1_3_Weights.verify(weights)
return _mnasnet(1.3, weights, progress, **kwargs)
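
# Usage sketch: only the 0.5 and 1.0 depth multipliers ship checkpoints here;
# MNASNet0_75_Weights and MNASNet1_3_Weights are empty enums, so their builders
# can only be called with weights=None. Assumes a torchvision build with this port.
from torchvision.models import mnasnet0_5, MNASNet0_5_Weights

model = mnasnet0_5(weights=MNASNet0_5_Weights.DEFAULT)
print(MNASNet0_5_Weights.DEFAULT.meta["acc@1"])  # 67.734
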
from .mobilenetv2 import * # noqa: F401, F403
from .mobilenetv3 import * # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all
__all__ = mv2_all + mv3_all
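
# Sketch: after these re-exports, both MobileNet families resolve from the same
# package namespace (the mobilenet_v3_* names are assumed from the mobilenetv3 module):
from torchvision.models import mobilenet_v2, mobilenet_v3_large
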
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
_COMMON_META = {
"task": "image_classification",
"architecture": "MobileNetV2",
"publication_year": 2018,
"num_params": 3504872,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class MobileNet_V2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
"acc@1": 71.878,
"acc@5": 90.286,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
"acc@1": 72.154,
"acc@5": 90.822,
},
)
DEFAULT = IMAGENET1K_V2
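
# A one-line check of the aliasing above: an enum member assigned an existing
# value becomes an alias of it, so both names resolve to the same object.
assert MobileNet_V2_Weights.DEFAULT is MobileNet_V2_Weights.IMAGENET1K_V2
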
@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
*, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
weights = MobileNet_V2_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MobileNetV2(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
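
# Usage sketch for the builder above: each checkpoint bundles its own metadata
# and inference preset, so evaluation code no longer hard-codes normalization.
# Assumes a torchvision build that includes this port.
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

weights = MobileNet_V2_Weights.IMAGENET1K_V2  # newer recipe: acc@1 72.154 vs 71.878 for V1
model = mobilenet_v2(weights=weights).eval()
preprocess = weights.transforms()  # ImageClassificationEval(crop_size=224, resize_size=232)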