Unverified Commit 11bd2eaa authored by Vasilis Vryniotis, committed by GitHub

Port Multi-weight support from prototype to main (#5618)



* Moving base files outside of prototype and porting AlexNet, ConvNeXt, DenseNet and EfficientNet.

* Porting googlenet

* Porting inception

* Porting mnasnet

* Porting mobilenetv2

* Porting mobilenetv3

* Porting regnet

* Porting resnet

* Porting shufflenetv2

* Porting squeezenet

* Porting vgg

* Porting vit

* Fix docstrings

* Fixing imports

* Adding missing import

* Fix mobilenet imports

* Fix tests

* Fix prototype tests

* Exclude get_weight from models on test

* Fix init files

* Porting googlenet

* Porting inception

* Porting mobilenetv2

* Porting mobilenetv3

* Porting resnet

* Porting shufflenetv2

* Fix test and linter

* Fixing docs.

* Porting Detection models (#5617)

* Fix inits

* Fix docs

* Port faster_rcnn

* Port fcos

* Port keypoint_rcnn

* Port mask_rcnn

* Port retinanet

* Port ssd

* Port ssdlite

* Fix linter

* Fixing tests

* Fixing tests

* Fixing vgg test

* Porting Optical Flow, Segmentation, Video models (#5619)

* Porting raft

* Porting video resnet

* Porting deeplabv3

* Porting fcn and lraspp

* Fixing the tests and linter

* Porting docs, examples, tutorials and galleries (#5620)

* Fix examples, tutorials and gallery

* Update gallery/plot_optical_flow.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Fix import

* Revert hardcoded normalization

* Fix uncommitted changes

* Fix bug

* Fix more bugs

* Making resize optional for segmentation

* Fixing preset

* Fix mypy

* Fixing documentation strings

* Fix flake8

* Minor refactoring
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Resolve conflict

* Porting model tests (#5622)

* Porting tests

* Remove unnecessary variable

* Fix linter

* Move prototype to extended tests

* Fix download models job

* Update CI on Multiweight branch to use the new weight download approach (#5628)

* port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal

* Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624)
Co-authored-by: Anton Thomma <anton@pri.co.nz>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* pre-download model weights in CI docs build (#5625)

* pre-download model weights in CI docs build

* move changes into template

* change docs image

* Regenerated config.yml
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>

* Porting reference scripts and updating presets (#5629)

* Making _preset.py classes

* Remove support for targets in presets.

* Rewriting the video preset

* Adding tests to check that the bundled transforms are JIT scriptable

* Rename all presets from *Eval to *Inference

* Minor refactoring

* Remove --prototype and --pretrained from reference scripts

* Remove pretrained_backbone refs

* Corrections and simplifications

* Fixing bug

* Fixing linter

* Fix flake8

* Restore documentation example

* Minor fixes

* Fix optical flow missing param

* Fixing commands

* Adding weights_backbone support in detection and segmentation

* Updating the commands for InceptionV3

* Setting `weights_backbone` to its fully BC value (#5653)

* Replace default `weights_backbone=None` with its BC values.

* Fixing tests

* Fix linter

* Update docs.

* Update preprocessing on reference scripts.

* Change qat/ptq to their full values.

* Refactoring preprocessing

* Fix video preset

* No initialization on VGG if pretrained

* Fix warning messages for backbone utils.

* Adding star to all preset constructors.

* Fix mypy.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>
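In practice, the ported multi-weight API works as sketched below. This is a minimal usage illustration, assuming the builders and weight enums are importable from torchvision.models as this PR intends; the checkpoint is downloaded on first use.
import torch
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights

# Each checkpoint is an enum member bundling its URL, inference preset and metadata.
weights = MobileNet_V3_Large_Weights.IMAGENET1K_V2
model = mobilenet_v3_large(weights=weights).eval()

# The bundled preset reproduces the exact preprocessing the weights expect.
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)

with torch.no_grad():
    scores = model(batch).squeeze(0).softmax(0)
class_id = int(scores.argmax())
print(weights.meta["categories"][class_id], float(scores[class_id]))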
from functools import partial
from typing import Any, Optional, List
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"MobileNetV3",
"MobileNet_V3_Large_Weights",
"MobileNet_V3_Small_Weights",
"mobilenet_v3_large",
"mobilenet_v3_small",
]
def _mobilenet_v3(
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> MobileNetV3:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "image_classification",
"architecture": "MobileNetV3",
"publication_year": 2019,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class MobileNet_V3_Large_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 5483032,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
"acc@1": 74.042,
"acc@5": 91.340,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 5483032,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
"acc@1": 75.274,
"acc@5": 92.566,
},
)
DEFAULT = IMAGENET1K_V2
class MobileNet_V3_Small_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2542856,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
"acc@1": 67.668,
"acc@5": 87.402,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
*, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
weights = MobileNet_V3_Large_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
*, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
weights = MobileNet_V3_Small_Weights.verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
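# Usage sketch (illustration, not part of the ported files): the
# handle_legacy_interface decorator keeps the deprecated `pretrained` flag
# working by remapping it to the enum value named in the decorator, so the two
# calls below are assumed to build identical models; the legacy call
# additionally emits a deprecation warning, and pretrained=False maps to
# weights=None.
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights

legacy = mobilenet_v3_small(pretrained=True)  # remapped to IMAGENET1K_V1
new = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)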
from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights
from typing import Optional
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
from torchvision.prototype.transforms import OpticalFlowEval
from torchvision.transforms.functional import InterpolationMode
from .._api import WeightsEnum
from .._api import Weights
from .._utils import handle_legacy_interface
__all__ = (
"RAFT",
"raft_large",
"raft_small",
"Raft_Large_Weights",
"Raft_Small_Weights",
)
_COMMON_META = {
"task": "optical_flow",
"architecture": "RAFT",
"publication_year": 2020,
"interpolation": InterpolationMode.BILINEAR,
}
class Raft_Large_Weights(WeightsEnum):
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-things.pth)
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 1.4411,
"sintel_train_finalpass_epe": 2.7894,
"kitti_train_per_image_epe": 5.0172,
"kitti_train_f1-all": 17.4506,
},
)
C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_train_cleanpass_epe": 1.3822,
"sintel_train_finalpass_epe": 2.7161,
"kitti_train_per_image_epe": 4.5118,
"kitti_train_f1-all": 16.0679,
},
)
C_T_SKHT_V1 = Weights(
# Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_test_cleanpass_epe": 1.94,
"sintel_test_finalpass_epe": 3.18,
},
)
C_T_SKHT_V2 = Weights(
# Chairs + Things + Sintel fine-tuning, i.e.:
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_test_cleanpass_epe": 1.819,
"sintel_test_finalpass_epe": 3.067,
},
)
C_T_SKHT_K_V1 = Weights(
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/princeton-vl/RAFT",
"kitti_test_f1-all": 5.10,
},
)
C_T_SKHT_K_V2 = Weights(
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
# Same as CT_SKHT with extra fine-tuning on Kitti
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"kitti_test_f1-all": 5.19,
},
)
DEFAULT = C_T_SKHT_V2
class Raft_Small_Weights(WeightsEnum):
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-small.pth)
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 990162,
"recipe": "https://github.com/princeton-vl/RAFT",
"sintel_train_cleanpass_epe": 2.1231,
"sintel_train_finalpass_epe": 3.2790,
"kitti_train_per_image_epe": 7.6557,
"kitti_train_f1-all": 25.2801,
},
)
C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 990162,
"recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
"sintel_train_cleanpass_epe": 1.9901,
"sintel_train_finalpass_epe": 3.2831,
"kitti_train_per_image_epe": 7.5978,
"kitti_train_f1-all": 25.2369,
},
)
DEFAULT = C_T_V2
@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
"""RAFT model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args:
weights (Raft_Large_Weights, optional): The pretrained weights to use.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
Returns:
nn.Module: The model.
"""
weights = Raft_Large_Weights.verify(weights)
model = _raft(
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_block=ResidualBlock,
feature_encoder_norm_layer=InstanceNorm2d,
# Context encoder
context_encoder_layers=(64, 64, 96, 128, 256),
context_encoder_block=ResidualBlock,
context_encoder_norm_layer=BatchNorm2d,
# Correlation block
corr_block_num_levels=4,
corr_block_radius=4,
# Motion encoder
motion_encoder_corr_layers=(256, 192),
motion_encoder_flow_layers=(128, 64),
motion_encoder_out_channels=128,
# Recurrent block
recurrent_block_hidden_state_size=128,
recurrent_block_kernel_size=((1, 5), (5, 1)),
recurrent_block_padding=((0, 2), (2, 0)),
# Flow head
flow_head_hidden_size=256,
# Mask predictor
use_mask_predictor=True,
**kwargs,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
"""RAFT "small" model from
`RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
Args:
weights (Raft_Small_Weights, optional): The pretrained weights to use.
progress (bool): If True, displays a progress bar of the download to stderr
kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
to override any default.
Returns:
nn.Module: The model.
"""
weights = Raft_Small_Weights.verify(weights)
model = _raft(
# Feature encoder
feature_encoder_layers=(32, 32, 64, 96, 128),
feature_encoder_block=BottleneckBlock,
feature_encoder_norm_layer=InstanceNorm2d,
# Context encoder
context_encoder_layers=(32, 32, 64, 96, 160),
context_encoder_block=BottleneckBlock,
context_encoder_norm_layer=None,
# Correlation block
corr_block_num_levels=4,
corr_block_radius=3,
# Motion encoder
motion_encoder_corr_layers=(96,),
motion_encoder_flow_layers=(64, 32),
motion_encoder_out_channels=82,
# Recurrent block
recurrent_block_hidden_state_size=96,
recurrent_block_kernel_size=(3,),
recurrent_block_padding=(1,),
# Flow head
flow_head_hidden_size=128,
# Mask predictor
use_mask_predictor=False,
**kwargs,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
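# Usage sketch for the ported RAFT builders (illustration, not part of the
# ported files). Assumes the OpticalFlowEval preset takes and returns the two
# frames; RAFT requires spatial dims divisible by 8 and returns the list of
# iteratively refined flow predictions.
import torch
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights

weights = Raft_Small_Weights.C_T_V2
model = raft_small(weights=weights).eval()

img1 = torch.rand(1, 3, 224, 224)  # frame at time t
img2 = torch.rand(1, 3, 224, 224)  # frame at time t + 1
img1, img2 = weights.transforms()(img1, img2)

with torch.no_grad():
    flow_predictions = model(img1, img2)
final_flow = flow_predictions[-1]  # shape (1, 2, 224, 224)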
from .googlenet import *
from .inception import *
from .mobilenet import *
from .resnet import *
from .shufflenetv2 import *
import warnings
from functools import partial
from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.googlenet import (
QuantizableGoogLeNet,
_replace_relu,
quantize_model,
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..googlenet import GoogLeNet_Weights
__all__ = [
"QuantizableGoogLeNet",
"GoogLeNet_QuantizedWeights",
"googlenet",
]
class GoogLeNet_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "GoogLeNet",
"publication_year": 2014,
"num_params": 6624904,
"size": (224, 224),
"min_size": (15, 15),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": GoogLeNet_Weights.IMAGENET1K_V1,
"acc@1": 69.826,
"acc@5": 89.404,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else GoogLeNet_Weights.IMAGENET1K_V1,
)
)
def googlenet(
*,
weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableGoogLeNet:
weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
else:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
return model
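# Usage sketch (illustration, not part of the ported files): with quantize=True
# the builder reads the backend recorded in weights.meta["backend"] (fbgemm,
# the x86 server backend) and returns an already-quantized model. The legacy
# spelling googlenet(pretrained=True, quantize=True) is remapped to the same
# quantized enum by the decorator above.
import torch
from torchvision.models.quantization import googlenet, GoogLeNet_QuantizedWeights

weights = GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
model = googlenet(weights=weights, quantize=True).eval()
batch = weights.transforms()(torch.rand(3, 256, 256)).unsqueeze(0)
out = model(batch)  # (1, 1000)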
from functools import partial
from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.inception import (
QuantizableInception3,
_replace_relu,
quantize_model,
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..inception import Inception_V3_Weights
__all__ = [
"QuantizableInception3",
"Inception_V3_QuantizedWeights",
"inception_v3",
]
class Inception_V3_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342),
meta={
"task": "image_classification",
"architecture": "InceptionV3",
"publication_year": 2015,
"num_params": 27161264,
"size": (299, 299),
"min_size": (75, 75),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
"unquantized": Inception_V3_Weights.IMAGENET1K_V1,
"acc@1": 77.176,
"acc@5": 93.354,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else Inception_V3_Weights.IMAGENET1K_V1,
)
)
def inception_v3(
*,
weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableInception3:
weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableInception3(**kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
if weights is not None:
if quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model.load_state_dict(weights.get_state_dict(progress=progress))
if not quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
return model
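# Usage sketch (illustration, not part of the ported files): same pattern as
# the quantized GoogLeNet above, but with the InceptionV3 input geometry
# (299x299 crop after a 342 resize, as encoded in the preset).
import torch
from torchvision.models.quantization import inception_v3, Inception_V3_QuantizedWeights

weights = Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1
model = inception_v3(weights=weights, quantize=True).eval()
batch = weights.transforms()(torch.rand(3, 400, 400)).unsqueeze(0)
out = model(batch)  # (1, 1000)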
from .mobilenetv2 import * # noqa: F401, F403
from .mobilenetv3 import * # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all
__all__ = mv2_all + mv3_all
from functools import partial
from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv2 import (
QuantizableInvertedResidual,
QuantizableMobileNetV2,
_replace_relu,
quantize_model,
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv2 import MobileNet_V2_Weights
__all__ = [
"QuantizableMobileNetV2",
"MobileNet_V2_QuantizedWeights",
"mobilenet_v2",
]
class MobileNet_V2_QuantizedWeights(WeightsEnum):
IMAGENET1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "MobileNetV2",
"publication_year": 2018,
"num_params": 3504872,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "qnnpack",
"quantization": "qat",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
"unquantized": MobileNet_V2_Weights.IMAGENET1K_V1,
"acc@1": 71.658,
"acc@5": 90.150,
},
)
DEFAULT = IMAGENET1K_QNNPACK_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1
if kwargs.get("quantize", False)
else MobileNet_V2_Weights.IMAGENET1K_V1,
)
)
def mobilenet_v2(
*,
weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV2:
weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
from functools import partial
from typing import Any, List, Optional, Union
import torch
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv3 import (
InvertedResidualConfig,
QuantizableInvertedResidual,
QuantizableMobileNetV3,
_replace_relu,
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf
__all__ = [
"QuantizableMobileNetV3",
"MobileNet_V3_Large_QuantizedWeights",
"mobilenet_v3_large",
]
def _mobilenet_v3_model(
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
**kwargs: Any,
) -> QuantizableMobileNetV3:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
if quantize:
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
torch.ao.quantization.prepare_qat(model, inplace=True)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if quantize:
torch.ao.quantization.convert(model, inplace=True)
model.eval()
return model
class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
IMAGENET1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "MobileNetV3",
"publication_year": 2019,
"num_params": 5483032,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "qnnpack",
"quantization": "qat",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
"unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
"acc@1": 73.004,
"acc@5": 90.858,
},
)
DEFAULT = IMAGENET1K_QNNPACK_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
if kwargs.get("quantize", False)
else MobileNet_V3_Large_Weights.IMAGENET1K_V1,
)
)
def mobilenet_v3_large(
*,
weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV3:
weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights)
inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
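# Usage sketch (illustration, not part of the ported files): unlike the
# post-training-quantized models above, this builder replays the QAT recipe,
# i.e. fuse -> qconfig -> prepare_qat -> load weights -> convert, and hands
# back a converted model already in eval mode. qnnpack is a mobile backend; on
# a desktop build the quantized engine may need to be selected explicitly
# (assumption: the engine is compiled into your PyTorch build).
import torch
from torchvision.models.quantization import mobilenet_v3_large, MobileNet_V3_Large_QuantizedWeights

torch.backends.quantized.engine = "qnnpack"
weights = MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
model = mobilenet_v3_large(weights=weights, quantize=True)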
from functools import partial
from typing import Any, List, Optional, Type, Union
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.resnet import (
QuantizableBasicBlock,
QuantizableBottleneck,
QuantizableResNet,
_replace_relu,
quantize_model,
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
__all__ = [
"QuantizableResNet",
"ResNet18_QuantizedWeights",
"ResNet50_QuantizedWeights",
"ResNeXt101_32X8D_QuantizedWeights",
"resnet18",
"resnet50",
"resnext101_32x8d",
]
def _resnet(
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
layers: List[int],
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
**kwargs: Any,
) -> QuantizableResNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableResNet(block, layers, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "image_classification",
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}
class ResNet18_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 11689512,
"unquantized": ResNet18_Weights.IMAGENET1K_V1,
"acc@1": 69.494,
"acc@5": 88.882,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V1
class ResNet50_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 25557032,
"unquantized": ResNet50_Weights.IMAGENET1K_V1,
"acc@1": 75.920,
"acc@5": 92.814,
},
)
IMAGENET1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 25557032,
"unquantized": ResNet50_Weights.IMAGENET1K_V2,
"acc@1": 80.282,
"acc@5": 94.976,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V2
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 88791336,
"unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
"acc@1": 78.986,
"acc@5": 94.480,
},
)
IMAGENET1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 88791336,
"unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
"acc@1": 82.574,
"acc@5": 96.132,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V2
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNet18_Weights.IMAGENET1K_V1,
)
)
def resnet18(
*,
weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)
return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNet50_Weights.IMAGENET1K_V1,
)
)
def resnet50(
*,
weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
)
)
def resnext101_32x8d(
*,
weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
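# Illustration (not part of the ported files): every quantized checkpoint links
# back to its float counterpart via meta["unquantized"], so the accuracy cost
# of quantization can be read straight off the metadata (enum members are
# assumed to expose the underlying Weights fields directly).
from torchvision.models.quantization import ResNet50_QuantizedWeights

q = ResNet50_QuantizedWeights.DEFAULT  # alias for IMAGENET1K_FBGEMM_V2
f = q.meta["unquantized"]              # ResNet50_Weights.IMAGENET1K_V2
print(q.meta["acc@1"], f.meta["acc@1"])  # 80.282 vs 80.858 acc@1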
from functools import partial
from typing import Any, List, Optional, Union
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.shufflenetv2 import (
QuantizableShuffleNetV2,
_replace_relu,
quantize_model,
)
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
__all__ = [
"QuantizableShuffleNetV2",
"ShuffleNet_V2_X0_5_QuantizedWeights",
"ShuffleNet_V2_X1_0_QuantizedWeights",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_0",
]
def _shufflenetv2(
stages_repeats: List[int],
stages_out_channels: List[int],
*,
weights: Optional[WeightsEnum],
progress: bool,
quantize: bool,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
_replace_relu(model)
if quantize:
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "image_classification",
"architecture": "ShuffleNetV2",
"publication_year": 2018,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"backend": "fbgemm",
"quantization": "ptq",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}
class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 1366792,
"unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
"acc@1": 57.972,
"acc@5": 79.780,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V1
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2278604,
"unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
"acc@1": 68.360,
"acc@5": 87.582,
},
)
DEFAULT = IMAGENET1K_FBGEMM_V1
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
)
)
def shufflenet_v2_x0_5(
*,
weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)
return _shufflenetv2(
[4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
)
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
if kwargs.get("quantize", False)
else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
)
)
def shufflenet_v2_x1_0(
*,
weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
progress: bool = True,
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)
return _shufflenetv2(
[4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
)
from functools import partial
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"RegNet",
"RegNet_Y_400MF_Weights",
"RegNet_Y_800MF_Weights",
"RegNet_Y_1_6GF_Weights",
"RegNet_Y_3_2GF_Weights",
"RegNet_Y_8GF_Weights",
"RegNet_Y_16GF_Weights",
"RegNet_Y_32GF_Weights",
"RegNet_Y_128GF_Weights",
"RegNet_X_400MF_Weights",
"RegNet_X_800MF_Weights",
"RegNet_X_1_6GF_Weights",
"RegNet_X_3_2GF_Weights",
"RegNet_X_8GF_Weights",
"RegNet_X_16GF_Weights",
"RegNet_X_32GF_Weights",
"regnet_y_400mf",
"regnet_y_800mf",
"regnet_y_1_6gf",
"regnet_y_3_2gf",
"regnet_y_8gf",
"regnet_y_16gf",
"regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_400mf",
"regnet_x_800mf",
"regnet_x_1_6gf",
"regnet_x_3_2gf",
"regnet_x_8gf",
"regnet_x_16gf",
"regnet_x_32gf",
]
_COMMON_META = {
"task": "image_classification",
"architecture": "RegNet",
"publication_year": 2020,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
def _regnet(
block_params: BlockParams,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> RegNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
class RegNet_Y_400MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 4344144,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 74.046,
"acc@5": 91.716,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 4344144,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 75.804,
"acc@5": 92.742,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_800MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 6432512,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 76.420,
"acc@5": 93.136,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 6432512,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 78.828,
"acc@5": 94.502,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_1_6GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 11202430,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.950,
"acc@5": 93.966,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 11202430,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 80.876,
"acc@5": 95.444,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_3_2GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 19436338,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.948,
"acc@5": 94.576,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 19436338,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.982,
"acc@5": 95.972,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_8GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 39381472,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.032,
"acc@5": 95.048,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 39381472,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.828,
"acc@5": 96.330,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_16GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 83590140,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.424,
"acc@5": 95.240,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 83590140,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.886,
"acc@5": 96.328,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_32GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 145046770,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.878,
"acc@5": 95.340,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 145046770,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 83.368,
"acc@5": 96.498,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum):
# weights are not available yet.
pass
class RegNet_X_400MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 5495976,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 72.834,
"acc@5": 90.950,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 5495976,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 74.864,
"acc@5": 92.322,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_800MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 7259656,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 75.212,
"acc@5": 92.348,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 7259656,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 77.522,
"acc@5": 93.826,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_1_6GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 9190136,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.040,
"acc@5": 93.440,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 9190136,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 79.668,
"acc@5": 94.922,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_3_2GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 15296552,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.364,
"acc@5": 93.992,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 15296552,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.196,
"acc@5": 95.430,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_8GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 39572648,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 79.344,
"acc@5": 94.686,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 39572648,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.682,
"acc@5": 95.678,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_16GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 54278536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.058,
"acc@5": 94.944,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 54278536,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.716,
"acc@5": 96.196,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_32GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 107811560,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.622,
"acc@5": 95.248,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 107811560,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 83.014,
"acc@5": 96.288,
},
)
DEFAULT = IMAGENET1K_V2
@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_128GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_800MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_1_6GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_3_2GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_8GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_16GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_32GF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
return _regnet(params, weights, progress, **kwargs)
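# Illustration (not part of the ported files): the RegNet builders above differ
# only in their BlockParams; a custom, hypothetical variant can be assembled the
# same way, though no pretrained weights exist for it. The parameters below are
# copied from regnet_y_400mf.
from functools import partial
from torch import nn
from torchvision.models.regnet import RegNet, BlockParams

params = BlockParams.from_init_params(
    depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25
)
model = RegNet(params, norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))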
from functools import partial
from typing import Any, List, Optional, Type, Union
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"ResNet",
"ResNet18_Weights",
"ResNet34_Weights",
"ResNet50_Weights",
"ResNet101_Weights",
"ResNet152_Weights",
"ResNeXt50_32X4D_Weights",
"ResNeXt101_32X8D_Weights",
"Wide_ResNet50_2_Weights",
"Wide_ResNet101_2_Weights",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"wide_resnet50_2",
"wide_resnet101_2",
]
def _resnet(
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> ResNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ResNet(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "image_classification",
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class ResNet18_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 11689512,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 69.758,
"acc@5": 89.078,
},
)
DEFAULT = IMAGENET1K_V1
class ResNet34_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 21797672,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 73.314,
"acc@5": 91.420,
},
)
DEFAULT = IMAGENET1K_V1
class ResNet50_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 25557032,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 76.130,
"acc@5": 92.862,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 25557032,
"recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
"acc@1": 80.858,
"acc@5": 95.434,
},
)
DEFAULT = IMAGENET1K_V2
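# Each checkpoint bundles its own inference preset and metadata. A minimal
# sketch of selecting a specific ResNet50 checkpoint (names as defined above):
#
#   weights = ResNet50_Weights.IMAGENET1K_V2   # or .DEFAULT, currently an alias of V2
#   preprocess = weights.transforms()          # the bundled ImageClassificationEval preset
#   print(weights.meta["acc@1"])               # 80.858
#
# Note that `DEFAULT` may point to improved weights in future releases.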
class ResNet101_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 44549160,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 77.374,
"acc@5": 93.546,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 44549160,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.886,
"acc@5": 95.780,
},
)
DEFAULT = IMAGENET1K_V2
class ResNet152_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 60192808,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 78.312,
"acc@5": 94.046,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 60192808,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.284,
"acc@5": 96.002,
},
)
DEFAULT = IMAGENET1K_V2
class ResNeXt50_32X4D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 25028904,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 77.618,
"acc@5": 93.698,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 25028904,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.198,
"acc@5": 95.340,
},
)
DEFAULT = IMAGENET1K_V2
class ResNeXt101_32X8D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 88791336,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 79.312,
"acc@5": 94.526,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 88791336,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 82.834,
"acc@5": 96.228,
},
)
DEFAULT = IMAGENET1K_V2
class Wide_ResNet50_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 68883240,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.468,
"acc@5": 94.086,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 68883240,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 81.602,
"acc@5": 95.758,
},
)
DEFAULT = IMAGENET1K_V2
class Wide_ResNet101_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 126886696,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.848,
"acc@5": 94.284,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 126886696,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.510,
"acc@5": 96.020,
},
)
DEFAULT = IMAGENET1K_V2
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
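# The `handle_legacy_interface` decorator above keeps the deprecated
# `pretrained` boolean working. A sketch of the equivalence (the deprecation
# warning is omitted):
#
#   resnet18(pretrained=True)                          # legacy call...
#   resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)   # ...maps to this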
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
weights = ResNet34_Weights.verify(weights)
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
weights = ResNet50_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
weights = ResNet101_Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
weights = ResNet152_Weights.verify(weights)
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
*, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
weights = ResNeXt50_32X4D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 4)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
*, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
weights = ResNeXt101_32X8D_Weights.verify(weights)
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
*, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
weights = Wide_ResNet50_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
*, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
weights = Wide_ResNet101_2_Weights.verify(weights)
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
from .fcn import *
from .lraspp import *
from .deeplabv3 import *
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import SemanticSegmentationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import resnet50, resnet101
from ..resnet import ResNet50_Weights, ResNet101_Weights
__all__ = [
"DeepLabV3",
"DeepLabV3_ResNet50_Weights",
"DeepLabV3_ResNet101_Weights",
"DeepLabV3_MobileNet_V3_Large_Weights",
"deeplabv3_mobilenet_v3_large",
"deeplabv3_resnet50",
"deeplabv3_resnet101",
]
_COMMON_META = {
"task": "image_semantic_segmentation",
"architecture": "DeepLabV3",
"publication_year": 2017,
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class DeepLabV3_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 42004074,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
"mIoU": 66.4,
"acc": 92.4,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 60996202,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
"mIoU": 67.4,
"acc": 92.4,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 11029328,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
"mIoU": 60.3,
"acc": 91.2,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet50(
*,
weights: Optional[DeepLabV3_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
weights = DeepLabV3_ResNet50_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
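# Note the precedence implemented above: when full-model `weights` are given,
# `weights_backbone` is ignored, since the checkpoint already contains backbone
# parameters. A minimal sketch of the two modes (10 classes is an example value):
#
#   # Fine-tuning: pretrained backbone only, fresh segmentation head
#   model = deeplabv3_resnet50(weights_backbone=ResNet50_Weights.IMAGENET1K_V1, num_classes=10)
#   # Inference: full pretrained model, 21 VOC classes inferred from metadata
#   model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)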
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet101(
*,
weights: Optional[DeepLabV3_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
weights = DeepLabV3_ResNet101_Weights.verify(weights)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def deeplabv3_mobilenet_v3_large(
*,
weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import SemanticSegmentationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]
_COMMON_META = {
"task": "image_semantic_segmentation",
"architecture": "FCN",
"publication_year": 2014,
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FCN_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 35322218,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
"mIoU": 60.5,
"acc": 91.4,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class FCN_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 54314346,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
"mIoU": 63.7,
"acc": 91.9,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcn_resnet50(
*,
weights: Optional[FCN_ResNet50_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
**kwargs: Any,
) -> FCN:
weights = FCN_ResNet50_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
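# The FCN builders follow the same weights/weights_backbone precedence as the
# DeepLabV3 builders above, e.g. (a sketch):
#
#   model = fcn_resnet50(weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)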
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def fcn_resnet101(
*,
weights: Optional[FCN_ResNet101_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101_Weights] = None,
**kwargs: Any,
) -> FCN:
weights = FCN_ResNet101_Weights.verify(weights)
weights_backbone = ResNet101_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import SemanticSegmentationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
"task": "image_semantic_segmentation",
"architecture": "LRASPP",
"publication_year": 2019,
"num_params": 3221538,
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
"mIoU": 57.9,
"acc": 91.2,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
@handle_legacy_interface(
weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def lraspp_mobilenet_v3_large(
*,
weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
**kwargs: Any,
) -> LRASPP:
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
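# Unlike the other segmentation builders, LRASPP has no auxiliary head, so a
# truthy `aux_loss` is rejected explicitly above. A minimal sketch:
#
#   model = lraspp_mobilenet_v3_large(
#       weights=LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1
#   )
#   # lraspp_mobilenet_v3_large(aux_loss=True) raises NotImplementedError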
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"ShuffleNetV2",
"ShuffleNet_V2_X0_5_Weights",
"ShuffleNet_V2_X1_0_Weights",
"ShuffleNet_V2_X1_5_Weights",
"ShuffleNet_V2_X2_0_Weights",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_0",
"shufflenet_v2_x1_5",
"shufflenet_v2_x2_0",
]
def _shufflenetv2(
weights: Optional[WeightsEnum],
progress: bool,
*args: Any,
**kwargs: Any,
) -> ShuffleNetV2:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ShuffleNetV2(*args, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
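# The positional `*args` forwarded above are ShuffleNetV2's `stages_repeats`
# and `stages_out_channels` (see the builders below), e.g. `[4, 8, 4]` repeats
# with `[24, 48, 96, 192, 1024]` channels for the x0.5 variant.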
_COMMON_META = {
"task": "image_classification",
"architecture": "ShuffleNetV2",
"publication_year": 2018,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
}
class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 1366792,
"acc@1": 69.362,
"acc@5": 88.316,
},
)
DEFAULT = IMAGENET1K_V1
class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2278604,
"acc@1": 60.552,
"acc@5": 81.746,
},
)
DEFAULT = IMAGENET1K_V1
class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
pass
class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass
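# The two empty enums above signal that no pretrained checkpoints are published
# for the x1.5 and x2.0 variants; their builders below therefore register
# `("pretrained", None)` with `handle_legacy_interface`, and `verify` only
# accepts `None` for them.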
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
*, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
*, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5(
*, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0(
*, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
_COMMON_META = {
"task": "image_classification",
"architecture": "SqueezeNet",
"publication_year": 2016,
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
}
class SqueezeNet1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"min_size": (21, 21),
"num_params": 1248424,
"acc@1": 58.092,
"acc@5": 80.420,
},
)
DEFAULT = IMAGENET1K_V1
class SqueezeNet1_1_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"min_size": (17, 17),
"num_params": 1235496,
"acc@1": 58.178,
"acc@5": 80.624,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
def squeezenet1_0(
*, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
weights = SqueezeNet1_0_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SqueezeNet("1_0", **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
def squeezenet1_1(
*, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
weights = SqueezeNet1_1_Weights.verify(weights)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SqueezeNet("1_1", **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
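# The `"1_0"` / `"1_1"` strings passed to SqueezeNet above select the paper's
# original architecture and the lighter 1.1 revision respectively. A sketch:
#
#   model = squeezenet1_1(weights=SqueezeNet1_1_Weights.IMAGENET1K_V1)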
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"VGG",
"VGG11_Weights",
"VGG11_BN_Weights",
"VGG13_Weights",
"VGG13_BN_Weights",
"VGG16_Weights",
"VGG16_BN_Weights",
"VGG19_Weights",
"VGG19_BN_Weights",
"vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19",
"vgg19_bn",
]
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
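# `cfg` is a key into `cfgs` describing the convolutional layout: "A", "B",
# "D" and "E" correspond to VGG11, VGG13, VGG16 and VGG19 respectively (see
# the builders at the bottom of this file).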
_COMMON_META = {
"task": "image_classification",
"architecture": "VGG",
"publication_year": 2014,
"size": (224, 224),
"min_size": (32, 32),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
}
class VGG11_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132863336,
"acc@1": 69.020,
"acc@5": 88.628,
},
)
DEFAULT = IMAGENET1K_V1
class VGG11_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132868840,
"acc@1": 70.370,
"acc@5": 89.810,
},
)
DEFAULT = IMAGENET1K_V1
class VGG13_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133047848,
"acc@1": 69.928,
"acc@5": 89.246,
},
)
DEFAULT = IMAGENET1K_V1
class VGG13_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133053736,
"acc@1": 71.586,
"acc@5": 90.374,
},
)
DEFAULT = IMAGENET1K_V1
class VGG16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138357544,
"acc@1": 71.592,
"acc@5": 90.382,
},
)
# We port the features of a VGG16 backbone trained by amdegroot because, unlike the one in
# TorchVision, it uses the same input standardization method as the paper. Only the `features`
# weights have proper values; those on the `classifier` module are filled with NaNs.
IMAGENET1K_FEATURES = Weights(
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial(
ImageClassificationEval,
crop_size=224,
mean=(0.48235, 0.45882, 0.40784),
std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
),
meta={
**_COMMON_META,
"num_params": 138357544,
"categories": None,
"recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
"acc@1": float("nan"),
"acc@5": float("nan"),
},
)
DEFAULT = IMAGENET1K_V1
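# A sketch of consuming the IMAGENET1K_FEATURES checkpoint described above as a
# detection backbone; only `model.features` is usable, since the classifier
# weights are NaN by design:
#
#   weights = VGG16_Weights.IMAGENET1K_FEATURES
#   backbone = vgg16(weights=weights).features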
class VGG16_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138365992,
"acc@1": 73.360,
"acc@5": 91.516,
},
)
DEFAULT = IMAGENET1K_V1
class VGG19_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143667240,
"acc@1": 72.376,
"acc@5": 90.876,
},
)
DEFAULT = IMAGENET1K_V1
class VGG19_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143678248,
"acc@1": 74.218,
"acc@5": 91.842,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs)