Unverified commit 11bd2eaa, authored by Vasilis Vryniotis and committed by GitHub.

Port Multi-weight support from prototype to main (#5618)



* Moving base files outside of prototype and porting AlexNet, ConvNeXt, DenseNet and EfficientNet.

* Porting googlenet

* Porting inception

* Porting mnasnet

* Porting mobilenetv2

* Porting mobilenetv3

* Porting regnet

* Porting resnet

* Porting shufflenetv2

* Porting squeezenet

* Porting vgg

* Porting vit

* Fix docstrings

* Fixing imports

* Adding missing import

* Fix mobilenet imports

* Fix tests

* Fix prototype tests

* Exclude get_weight from models on test

* Fix init files

* Porting googlenet

* Porting inception

* porting mobilenetv2

* porting mobilenetv3

* porting resnet

* porting shufflenetv2

* Fix test and linter

* Fixing docs.

* Porting Detection models (#5617)

* fix inits

* fix docs

* Port faster_rcnn

* Port fcos

* Port keypoint_rcnn

* Port mask_rcnn

* Port retinanet

* Port ssd

* Port ssdlite

* Fix linter

* Fixing tests

* Fixing tests

* Fixing vgg test

* Porting Optical Flow, Segmentation, Video models (#5619)

* Porting raft

* Porting video resnet

* Porting deeplabv3

* Porting fcn and lraspp

* Fixing the tests and linter

* Porting docs, examples, tutorials and galleries (#5620)

* Fix examples, tutorials and gallery

* Update gallery/plot_optical_flow.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Fix import

* Revert hardcoded normalization

* fix uncommitted changes

* Fix bug

* Fix more bugs

* Making resize optional for segmentation

* Fixing preset

* Fix mypy

* Fixing documentation strings

* Fix flake8

* minor refactoring
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Resolve conflict

* Porting model tests (#5622)

* Porting tests

* Remove unnecessary variable

* Fix linter

* Move prototype to extended tests

* Fix download models job

* Update CI on Multiweight branch to use the new weight download approach (#5628)

* port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal

* Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624)
Co-authored-by: Anton Thomma <anton@pri.co.nz>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* pre-download model weights in CI docs build (#5625)

* pre-download model weights in CI docs build

* move changes into template

* change docs image

* Regenerated config.yml
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>

* Porting reference scripts and updating presets (#5629)

* Making _preset.py classes

* Remove support of targets on presets.

* Rewriting the video preset

* Adding tests to check that the bundled transforms are JIT scriptable

* Rename all presets from *Eval to *Inference

* Minor refactoring

* Remove --prototype and --pretrained from reference scripts

* Remove pretrained_backbone refs

* Corrections and simplifications

* Fixing bug

* Fixing linter

* Fix flake8

* restore documentation example

* minor fixes

* fix optical flow missing param

* Fixing commands

* Adding weights_backbone support in detection and segmentation

* Updating the commands for InceptionV3

* Setting `weights_backbone` to its fully BC value (#5653)

* Replace default `weights_backbone=None` with its BC values.

* Fixing tests

* Fix linter

* Update docs.

* Update preprocessing on reference scripts.

* Change qat/ptq to their full values.

* Refactoring preprocessing

* Fix video preset

* No initialization on VGG if pretrained

* Fix warning messages for backbone utils.

* Adding star to all preset constructors.

* Fix mypy.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>
parent 375e4ab2
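In user code, the migration this commit performs looks like the following (a minimal sketch, assuming a torchvision build that contains this commit and an x86/fbgemm setup for the quantized path; the enum and builder used here are defined in the quantization diffs below):

    import torch
    from torchvision.models.quantization import googlenet, GoogLeNet_QuantizedWeights

    # The explicit weights enum replaces the old pretrained=True flag.
    weights = GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
    model = googlenet(weights=weights, quantize=True)
    model.eval()

    # Each enum value bundles its own inference preset and metadata.
    preprocess = weights.transforms()
    batch = preprocess(torch.rand(1, 3, 256, 256))
    print(model(batch).shape, weights.meta["acc@1"])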
torchvision/models/quantization/__init__.py:

from .googlenet import *
from .inception import *
from .mobilenet import *
from .resnet import *
from .shufflenetv2 import *
torchvision/models/quantization/googlenet.py:

import warnings
-from typing import Any, Optional
+from functools import partial
+from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
-from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls

-from ..._internally_replaced_utils import load_state_dict_from_url
+from ...transforms._presets import ImageClassification, InterpolationMode
+from .._api import WeightsEnum, Weights
+from .._meta import _IMAGENET_CATEGORIES
+from .._utils import handle_legacy_interface, _ovewrite_named_param
+from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights
from .utils import _fuse_modules, _replace_relu, quantize_model

-__all__ = ["QuantizableGoogLeNet", "googlenet"]
-
-quant_model_urls = {
-    # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
-    "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
-}
+__all__ = [
+    "QuantizableGoogLeNet",
+    "GoogLeNet_QuantizedWeights",
+    "googlenet",
+]


class QuantizableBasicConv2d(BasicConv2d):

@@ -103,8 +106,41 @@ class QuantizableGoogLeNet(GoogLeNet):
                m.fuse_model(is_qat)


+class GoogLeNet_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "task": "image_classification",
+            "architecture": "GoogLeNet",
+            "publication_year": 2014,
+            "num_params": 6624904,
+            "size": (224, 224),
+            "min_size": (15, 15),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "fbgemm",
+            "quantization": "Post Training Quantization",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+            "unquantized": GoogLeNet_Weights.IMAGENET1K_V1,
+            "acc@1": 69.826,
+            "acc@5": 89.404,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else GoogLeNet_Weights.IMAGENET1K_V1,
+    )
+)
def googlenet(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -117,49 +153,38 @@ def googlenet(
        GPU inference is not yet supported

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
-        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
-            Default: *False* when pretrained is True otherwise *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: True if ``pretrained=True``, else False.
    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" not in kwargs:
-            kwargs["aux_logits"] = False
-        if kwargs["aux_logits"]:
-            warnings.warn(
-                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
-            )
-        original_aux_logits = kwargs["aux_logits"]
-        kwargs["aux_logits"] = True
-        kwargs["init_weights"] = False
+    weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            _ovewrite_named_param(kwargs, "transform_input", True)
+        _ovewrite_named_param(kwargs, "aux_logits", True)
+        _ovewrite_named_param(kwargs, "init_weights", False)
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "backend" in weights.meta:
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
+    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]

-    if pretrained:
-        if quantize:
-            model_url = quant_model_urls["googlenet_" + backend]
-        else:
-            model_url = model_urls["googlenet"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
        if not original_aux_logits:
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
+        else:
+            warnings.warn(
+                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
+            )

    return model
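The `handle_legacy_interface` decorator above is what keeps old call sites alive: when a caller still passes `pretrained=True`, the flag is mapped to the enum returned by the lambda, which switches on `quantize`, and a deprecation warning is emitted (a behavioural sketch, not part of the diff):

    from torchvision.models.quantization import googlenet

    # Legacy calls keep working but warn, resolving to the enums noted below:
    m1 = googlenet(pretrained=True, quantize=True)   # GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
    m2 = googlenet(pretrained=True, quantize=False)  # GoogLeNet_Weights.IMAGENET1K_V1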
torchvision/models/quantization/inception.py:

import warnings
-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import inception as inception_module
-from torchvision.models.inception import InceptionOutputs
+from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights

-from ..._internally_replaced_utils import load_state_dict_from_url
+from ...transforms._presets import ImageClassification, InterpolationMode
+from .._api import WeightsEnum, Weights
+from .._meta import _IMAGENET_CATEGORIES
+from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model

__all__ = [
    "QuantizableInception3",
+    "Inception_V3_QuantizedWeights",
    "inception_v3",
]

-quant_model_urls = {
-    # fp32 weights ported from TensorFlow, quantized in PyTorch
-    "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
-}


class QuantizableBasicConv2d(inception_module.BasicConv2d):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

@@ -173,8 +172,41 @@ class QuantizableInception3(inception_module.Inception3):
                m.fuse_model(is_qat)


+class Inception_V3_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
+        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
+        meta={
+            "task": "image_classification",
+            "architecture": "InceptionV3",
+            "publication_year": 2015,
+            "num_params": 27161264,
+            "size": (299, 299),
+            "min_size": (75, 75),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "fbgemm",
+            "quantization": "Post Training Quantization",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+            "unquantized": Inception_V3_Weights.IMAGENET1K_V1,
+            "acc@1": 77.176,
+            "acc@5": 93.354,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else Inception_V3_Weights.IMAGENET1K_V1,
+    )
+)
def inception_v3(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -191,48 +223,35 @@ def inception_v3(
        GPU inference is not yet supported

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (Inception_V3_QuantizedWeights or Inception_V3_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
-        aux_logits (bool): If True, add an auxiliary branch that can improve training.
-            Default: *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: True if ``pretrained=True``, else False.
    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" in kwargs:
-            original_aux_logits = kwargs["aux_logits"]
-            kwargs["aux_logits"] = True
-        else:
-            original_aux_logits = False
+    weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            _ovewrite_named_param(kwargs, "transform_input", True)
+        _ovewrite_named_param(kwargs, "aux_logits", True)
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "backend" in weights.meta:
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
+    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)
    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]

-    if pretrained:
-        if quantize:
-            if not original_aux_logits:
-                model.aux_logits = False
-                model.AuxLogits = None
-            model_url = quant_model_urls["inception_v3_google_" + backend]
-        else:
-            model_url = inception_module.model_urls["inception_v3_google"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
-
-        if not quantize:
-            if not original_aux_logits:
-                model.aux_logits = False
-                model.AuxLogits = None
+    if weights is not None:
+        if quantize and not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+        if not quantize and not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None

    return model
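The Inception preset above encodes the model-specific preprocessing (resize to 342, center-crop to 299) that callers previously had to hard-code themselves. A quick check, under the same assumptions as the earlier sketch:

    import torch
    from torchvision.models.quantization import Inception_V3_QuantizedWeights

    preprocess = Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1.transforms()
    out = preprocess(torch.rand(3, 500, 400))
    assert out.shape[-2:] == (299, 299)  # short side resized to 342, then center-cropped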
torchvision/models/quantization/mobilenet.py:

-from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
-from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all
+from .mobilenetv2 import *  # noqa: F401, F403
+from .mobilenetv3 import *  # noqa: F401, F403
+from .mobilenetv2 import __all__ as mv2_all
+from .mobilenetv3 import __all__ as mv3_all

__all__ = mv2_all + mv3_all
torchvision/models/quantization/mobilenetv2.py:

-from typing import Any, Optional
+from functools import partial
+from typing import Any, Optional, Union

from torch import Tensor
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub
-from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
+from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights

-from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import Conv2dNormActivation
+from ...transforms._presets import ImageClassification, InterpolationMode
+from .._api import WeightsEnum, Weights
+from .._meta import _IMAGENET_CATEGORIES
+from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model

-__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
-
-quant_model_urls = {
-    "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth"
-}
+__all__ = [
+    "QuantizableMobileNetV2",
+    "MobileNet_V2_QuantizedWeights",
+    "mobilenet_v2",
+]


class QuantizableInvertedResidual(InvertedResidual):

@@ -60,8 +64,41 @@ class QuantizableMobileNetV2(MobileNetV2):
                m.fuse_model(is_qat)


+class MobileNet_V2_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_QNNPACK_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "task": "image_classification",
+            "architecture": "MobileNetV2",
+            "publication_year": 2018,
+            "num_params": 3504872,
+            "size": (224, 224),
+            "min_size": (1, 1),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "qnnpack",
+            "quantization": "Quantization Aware Training",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
+            "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1,
+            "acc@1": 71.658,
+            "acc@5": 90.150,
+        },
+    )
+    DEFAULT = IMAGENET1K_QNNPACK_V1
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1
+        if kwargs.get("quantize", False)
+        else MobileNet_V2_Weights.IMAGENET1K_V1,
+    )
+)
def mobilenet_v2(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -76,27 +113,25 @@ def mobilenet_v2(
        GPU inference is not yet supported

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet.
+        weights (MobileNet_V2_QuantizedWeights or MobileNet_V2_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize(bool): If True, returns a quantized model, else returns a float model
    """
+    weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights)
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "backend" in weights.meta:
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
+    backend = kwargs.pop("backend", "qnnpack")
+
    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "qnnpack"
        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]

-    if pretrained:
-        if quantize:
-            model_url = quant_model_urls["mobilenet_v2_" + backend]
-        else:
-            model_url = model_urls["mobilenet_v2"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-        model.load_state_dict(state_dict)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
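`_ovewrite_named_param`, used by every builder in this PR, fills in a kwarg that the checked-in weights require but raises if the caller passed a conflicting value, rather than silently loading mismatched tensors. A sketch of the observable behaviour:

    from torchvision.models.quantization import mobilenet_v2, MobileNet_V2_QuantizedWeights

    weights = MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1
    model = mobilenet_v2(weights=weights, quantize=True)  # num_classes is inferred as 1000
    try:
        mobilenet_v2(weights=weights, quantize=True, num_classes=10)
    except ValueError as err:
        print(err)  # a 10-class head conflicts with the 1000-class checkpoint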
torchvision/models/quantization/mobilenetv3.py:

-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

-from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
-from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
+from ...transforms._presets import ImageClassification, InterpolationMode
+from .._api import WeightsEnum, Weights
+from .._meta import _IMAGENET_CATEGORIES
+from .._utils import handle_legacy_interface, _ovewrite_named_param
+from ..mobilenetv3 import (
+    InvertedResidual,
+    InvertedResidualConfig,
+    MobileNetV3,
+    _mobilenet_v3_conf,
+    MobileNet_V3_Large_Weights,
+)
from .utils import _fuse_modules, _replace_relu

-__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
-
-quant_model_urls = {
-    "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
-}
+__all__ = [
+    "QuantizableMobileNetV3",
+    "MobileNet_V3_Large_QuantizedWeights",
+    "mobilenet_v3_large",
+]


class QuantizableSqueezeExcitation(SqueezeExcitation):

@@ -112,47 +122,73 @@ class QuantizableMobileNetV3(MobileNetV3):
                m.fuse_model(is_qat)


-def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
-    if model_url is None:
-        raise ValueError(f"No checkpoint is available for {arch}")
-    state_dict = load_state_dict_from_url(model_url, progress=progress)
-    model.load_state_dict(state_dict)
-
-
def _mobilenet_v3_model(
-    arch: str,
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
-    pretrained: bool,
+    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "backend" in weights.meta:
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
+    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
-        backend = "qnnpack"
-
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

-        if pretrained:
-            _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))

+    if quantize:
        torch.ao.quantization.convert(model, inplace=True)
        model.eval()
-    else:
-        if pretrained:
-            _load_weights(arch, model, model_urls.get(arch, None), progress)

    return model


+class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_QNNPACK_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "task": "image_classification",
+            "architecture": "MobileNetV3",
+            "publication_year": 2019,
+            "num_params": 5483032,
+            "size": (224, 224),
+            "min_size": (1, 1),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "qnnpack",
+            "quantization": "Quantization Aware Training",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
+            "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+            "acc@1": 73.004,
+            "acc@5": 90.858,
+        },
+    )
+    DEFAULT = IMAGENET1K_QNNPACK_V1
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
+        if kwargs.get("quantize", False)
+        else MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    )
+)
def mobilenet_v3_large(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -166,10 +202,12 @@ def mobilenet_v3_large(
        GPU inference is not yet supported

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet.
+        weights (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, returns a quantized model, else returns a float model
    """
-    arch = "mobilenet_v3_large"
-    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
-    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
+    weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights)
+
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
+    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
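Unlike the post-training-quantized models above, this checkpoint comes from quantization-aware training, so `_mobilenet_v3_model` prepares the network with QAT observers, loads the state dict into the prepared model, and only then converts and switches to eval. The public entry point hides that sequencing (same assumptions as the earlier sketches):

    from torchvision.models.quantization import mobilenet_v3_large, MobileNet_V3_Large_QuantizedWeights

    weights = MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
    model = mobilenet_v3_large(weights=weights, quantize=True)  # prepare_qat -> load -> convert -> eval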
torchvision/models/quantization/resnet.py:

+from functools import partial
from typing import Any, Type, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor
-from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
+from torchvision.models.resnet import (
+    Bottleneck,
+    BasicBlock,
+    ResNet,
+    ResNet18_Weights,
+    ResNet50_Weights,
+    ResNeXt101_32X8D_Weights,
+)

-from ..._internally_replaced_utils import load_state_dict_from_url
+from ...transforms._presets import ImageClassification, InterpolationMode
+from .._api import WeightsEnum, Weights
+from .._meta import _IMAGENET_CATEGORIES
+from .._utils import handle_legacy_interface, _ovewrite_named_param
from .utils import _fuse_modules, _replace_relu, quantize_model

-__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
-
-quant_model_urls = {
-    "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
-    "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
-    "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
-}
+__all__ = [
+    "QuantizableResNet",
+    "ResNet18_QuantizedWeights",
+    "ResNet50_QuantizedWeights",
+    "ResNeXt101_32X8D_QuantizedWeights",
+    "resnet18",
+    "resnet50",
+    "resnext101_32x8d",
+]


class QuantizableBasicBlock(BasicBlock):

@@ -109,38 +122,130 @@ class QuantizableResNet(ResNet):

def _resnet(
-    arch: str,
    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
    layers: List[int],
-    pretrained: bool,
+    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableResNet:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "backend" in weights.meta:
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
+    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableResNet(block, layers, **kwargs)
    _replace_relu(model)
    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]

-    if pretrained:
-        if quantize:
-            model_url = quant_model_urls[arch + "_" + backend]
-        else:
-            model_url = model_urls[arch]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


+_COMMON_META = {
+    "task": "image_classification",
+    "size": (224, 224),
+    "min_size": (1, 1),
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+    "backend": "fbgemm",
+    "quantization": "Post Training Quantization",
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+}
+
+
+class ResNet18_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "architecture": "ResNet",
+            "publication_year": 2015,
+            "num_params": 11689512,
+            "unquantized": ResNet18_Weights.IMAGENET1K_V1,
+            "acc@1": 69.494,
+            "acc@5": 88.882,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+class ResNet50_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "architecture": "ResNet",
+            "publication_year": 2015,
+            "num_params": 25557032,
+            "unquantized": ResNet50_Weights.IMAGENET1K_V1,
+            "acc@1": 75.920,
+            "acc@5": 92.814,
+        },
+    )
+    IMAGENET1K_FBGEMM_V2 = Weights(
+        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "architecture": "ResNet",
+            "publication_year": 2015,
+            "num_params": 25557032,
+            "unquantized": ResNet50_Weights.IMAGENET1K_V2,
+            "acc@1": 80.282,
+            "acc@5": 94.976,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V2
+
+
+class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "architecture": "ResNeXt",
+            "publication_year": 2016,
+            "num_params": 88791336,
+            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
+            "acc@1": 78.986,
+            "acc@5": 94.480,
+        },
+    )
+    IMAGENET1K_FBGEMM_V2 = Weights(
+        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "architecture": "ResNeXt",
+            "publication_year": 2016,
+            "num_params": 88791336,
+            "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
+            "acc@1": 82.574,
+            "acc@5": 96.132,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V2
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else ResNet18_Weights.IMAGENET1K_V1,
+    )
+)
def resnet18(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -149,33 +254,56 @@ def resnet18(
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
-    return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs)
+    weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)
+
+    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)


+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else ResNet50_Weights.IMAGENET1K_V1,
+    )
+)
def resnet50(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
-    return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs)
+    weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)
+
+    return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)


+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
+    )
+)
def resnext101_32x8d(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -184,10 +312,13 @@ def resnext101_32x8d(
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
-    kwargs["groups"] = 32
-    kwargs["width_per_group"] = 8
-    return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs)
+    weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
+
+    _ovewrite_named_param(kwargs, "groups", 32)
+    _ovewrite_named_param(kwargs, "width_per_group", 8)
+    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
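Since `DEFAULT` is an alias for the best available checkpoint rather than a separate artifact, callers can either pin an exact version or track improvements; for the quantized ResNet-50 above, the V2 recipe is worth roughly 4.4 acc@1 points in exchange for a 232-pixel resize preset:

    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    pinned = resnet50(weights=ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, quantize=True)
    best = resnet50(weights=ResNet50_QuantizedWeights.DEFAULT, quantize=True)  # IMAGENET1K_FBGEMM_V2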
torchvision/models/quantization/shufflenetv2.py:

-from typing import Any, Optional
+from functools import partial
+from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models import shufflenetv2

-from ..._internally_replaced_utils import load_state_dict_from_url
+from ...transforms._presets import ImageClassification, InterpolationMode
+from .._api import WeightsEnum, Weights
+from .._meta import _IMAGENET_CATEGORIES
+from .._utils import handle_legacy_interface, _ovewrite_named_param
+from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
from .utils import _fuse_modules, _replace_relu, quantize_model

__all__ = [
    "QuantizableShuffleNetV2",
+    "ShuffleNet_V2_X0_5_QuantizedWeights",
+    "ShuffleNet_V2_X1_0_QuantizedWeights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
]

-quant_model_urls = {
-    "shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
-    "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
-}


class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
    def __init__(self, *args: Any, **kwargs: Any) -> None:

@@ -73,39 +76,86 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):

def _shufflenetv2(
-    arch: str,
-    pretrained: bool,
+    stages_repeats: List[int],
+    stages_out_channels: List[int],
+    *,
+    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
-    *args: Any,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if "backend" in weights.meta:
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
+    backend = kwargs.pop("backend", "fbgemm")

-    model = QuantizableShuffleNetV2(*args, **kwargs)
+    model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
    _replace_relu(model)

    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]

-    if pretrained:
-        model_url: Optional[str] = None
-        if quantize:
-            model_url = quant_model_urls[arch + "_" + backend]
-        else:
-            model_url = shufflenetv2.model_urls[arch]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


+_COMMON_META = {
+    "task": "image_classification",
+    "architecture": "ShuffleNetV2",
+    "publication_year": 2018,
+    "size": (224, 224),
+    "min_size": (1, 1),
+    "categories": _IMAGENET_CATEGORIES,
+    "interpolation": InterpolationMode.BILINEAR,
+    "backend": "fbgemm",
+    "quantization": "Post Training Quantization",
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+}
+
+
+class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 1366792,
+            "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
+            "acc@1": 57.972,
+            "acc@5": 79.780,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
+    IMAGENET1K_FBGEMM_V1 = Weights(
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 2278604,
+            "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
+            "acc@1": 68.360,
+            "acc@5": 87.582,
+        },
+    )
+    DEFAULT = IMAGENET1K_FBGEMM_V1
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
+    )
+)
def shufflenet_v2_x0_5(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -116,17 +166,28 @@ def shufflenet_v2_x0_5(
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
+    weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)
+
    return _shufflenetv2(
-        "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs
+        [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )


+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
+        if kwargs.get("quantize", False)
+        else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
+    )
+)
def shufflenet_v2_x1_0(
-    pretrained: bool = False,
+    *,
+    weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,

@@ -137,10 +198,12 @@ def shufflenet_v2_x1_0(
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained
+            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
+    weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)
+
    return _shufflenetv2(
-        "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs
+        [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )
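Every fact about a checkpoint now lives in its `meta` dict, so it can be inspected without constructing the model or downloading anything, for example:

    from torchvision.models.quantization import ShuffleNet_V2_X1_0_QuantizedWeights

    w = ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
    print(w.meta["num_params"], w.meta["backend"], w.meta["acc@1"], w.meta["unquantized"])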
# Modified from
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/anynet.py
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
import math import math
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
...@@ -11,14 +6,31 @@ from typing import Any, Callable, List, Optional, Tuple ...@@ -11,14 +6,31 @@ from typing import Any, Callable, List, Optional, Tuple
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import _make_divisible from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
__all__ = [ __all__ = [
"RegNet", "RegNet",
"RegNet_Y_400MF_Weights",
"RegNet_Y_800MF_Weights",
"RegNet_Y_1_6GF_Weights",
"RegNet_Y_3_2GF_Weights",
"RegNet_Y_8GF_Weights",
"RegNet_Y_16GF_Weights",
"RegNet_Y_32GF_Weights",
"RegNet_Y_128GF_Weights",
"RegNet_X_400MF_Weights",
"RegNet_X_800MF_Weights",
"RegNet_X_1_6GF_Weights",
"RegNet_X_3_2GF_Weights",
"RegNet_X_8GF_Weights",
"RegNet_X_16GF_Weights",
"RegNet_X_32GF_Weights",
"regnet_y_400mf", "regnet_y_400mf",
"regnet_y_800mf", "regnet_y_800mf",
"regnet_y_1_6gf", "regnet_y_1_6gf",
...@@ -37,24 +49,6 @@ __all__ = [ ...@@ -37,24 +49,6 @@ __all__ = [
] ]
model_urls = {
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
}
class SimpleStemIN(Conv2dNormActivation): class SimpleStemIN(Conv2dNormActivation):
"""Simple stem for ImageNet: 3x3, BN, ReLU.""" """Simple stem for ImageNet: 3x3, BN, ReLU."""
...@@ -390,219 +384,652 @@ class RegNet(nn.Module): ...@@ -390,219 +384,652 @@ class RegNet(nn.Module):
return x return x
def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet: def _regnet(
block_params: BlockParams,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> RegNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
model = RegNet(block_params, norm_layer=norm_layer, **kwargs) model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if pretrained:
if arch not in model_urls: if weights is not None:
raise ValueError(f"No checkpoint is available for model type {arch}") model.load_state_dict(weights.get_state_dict(progress=progress))
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model return model
def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: _COMMON_META = {
"task": "image_classification",
"architecture": "RegNet",
"publication_year": 2020,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class RegNet_Y_400MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 4344144,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 74.046,
"acc@5": 91.716,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 4344144,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 75.804,
"acc@5": 92.742,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_800MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 6432512,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 76.420,
"acc@5": 93.136,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 6432512,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 78.828,
"acc@5": 94.502,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_1_6GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 11202430,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.950,
"acc@5": 93.966,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 11202430,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 80.876,
"acc@5": 95.444,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_3_2GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 19436338,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.948,
"acc@5": 94.576,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 19436338,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.982,
"acc@5": 95.972,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_8GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 39381472,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.032,
"acc@5": 95.048,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 39381472,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.828,
"acc@5": 96.330,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_16GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 83590140,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.424,
"acc@5": 95.240,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 83590140,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.886,
"acc@5": 96.328,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_32GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 145046770,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.878,
"acc@5": 95.340,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 145046770,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 83.368,
"acc@5": 96.498,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum):
# weights are not available yet.
pass
class RegNet_X_400MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 5495976,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 72.834,
"acc@5": 90.950,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 5495976,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 74.864,
"acc@5": 92.322,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_800MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 7259656,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 75.212,
"acc@5": 92.348,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 7259656,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 77.522,
"acc@5": 93.826,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_1_6GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 9190136,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.040,
"acc@5": 93.440,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 9190136,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 79.668,
"acc@5": 94.922,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_3_2GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 15296552,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.364,
"acc@5": 93.992,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 15296552,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.196,
"acc@5": 95.430,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_8GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 39572648,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 79.344,
"acc@5": 94.686,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 39572648,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.682,
"acc@5": 95.678,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_16GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 54278536,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.058,
"acc@5": 94.944,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 54278536,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.716,
"acc@5": 96.196,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_X_32GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 107811560,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.622,
"acc@5": 95.248,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 107811560,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 83.014,
"acc@5": 96.288,
},
)
DEFAULT = IMAGENET1K_V2
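
Each `Weights` entry above bundles the checkpoint URL, its inference preset and the metadata, so a caller can inspect a checkpoint before downloading anything. A minimal sketch of how these enums are meant to be consumed (illustrative only, not part of this diff):

    from torchvision.models import RegNet_Y_32GF_Weights

    weights = RegNet_Y_32GF_Weights.DEFAULT      # alias of IMAGENET1K_V2
    print(weights.meta["num_params"])            # 145046770
    print(weights.meta["acc@1"])                 # 83.368
    preprocess = weights.transforms()            # ImageClassification(crop_size=224, resize_size=232)
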
@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
""" """
Constructs a RegNetY_400MF architecture from Constructs a RegNetY_400MF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_. `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (RegNet_Y_400MF_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = RegNet_Y_400MF_Weights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_800MF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_Y_800MF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_800MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_1.6GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_Y_1_6GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_1_6GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_3.2GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_Y_3_2GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_3_2GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_8GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_Y_8GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_8GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_16GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_Y_16GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_16GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_32GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_Y_32GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_32GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_128GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    NOTE: Pretrained weights are not available for this model.

    Args:
        weights (RegNet_Y_128GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_Y_128GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_400MF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_400MF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_400MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_800MF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_800MF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_800MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_1.6GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_1_6GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_1_6GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_3.2GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_3_2GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_3_2GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_8GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_8GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_8GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_16GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_16GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_16GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet(params, weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_32GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (RegNet_X_32GF_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = RegNet_X_32GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
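
The `@handle_legacy_interface` decorator keeps the removed boolean flag working by remapping it onto the weight entry named in the decorator. Assuming the behaviour this PR describes, the following calls would be equivalent (sketch, not part of the diff):

    from torchvision.models import regnet_y_400mf, RegNet_Y_400MF_Weights

    m1 = regnet_y_400mf(weights=RegNet_Y_400MF_Weights.IMAGENET1K_V1)
    m2 = regnet_y_400mf(pretrained=True)   # legacy path, remapped with a deprecation warning
    m3 = regnet_y_400mf(weights=None)      # random initialization
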
from functools import partial
from typing import Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param

__all__ = [
    "ResNet",
    "ResNet18_Weights",
    "ResNet34_Weights",
    "ResNet50_Weights",
    "ResNet101_Weights",
    "ResNet152_Weights",
    "ResNeXt50_32X4D_Weights",
    "ResNeXt101_32X8D_Weights",
    "Wide_ResNet50_2_Weights",
    "Wide_ResNet101_2_Weights",
    "resnet18",
    "resnet34",
    "resnet50",
@@ -22,19 +35,6 @@ __all__ = [
]

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
@@ -284,102 +284,386 @@ class ResNet(nn.Module):

def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model

_COMMON_META = {
"task": "image_classification",
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class ResNet18_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 11689512,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 69.758,
"acc@5": 89.078,
},
)
DEFAULT = IMAGENET1K_V1
class ResNet34_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 21797672,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 73.314,
"acc@5": 91.420,
},
)
DEFAULT = IMAGENET1K_V1
class ResNet50_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 25557032,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 76.130,
"acc@5": 92.862,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 25557032,
"recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
"acc@1": 80.858,
"acc@5": 95.434,
},
)
DEFAULT = IMAGENET1K_V2
class ResNet101_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 44549160,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 77.374,
"acc@5": 93.546,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 44549160,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.886,
"acc@5": 95.780,
},
)
DEFAULT = IMAGENET1K_V2
class ResNet152_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 60192808,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 78.312,
"acc@5": 94.046,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
"publication_year": 2015,
"num_params": 60192808,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.284,
"acc@5": 96.002,
},
)
DEFAULT = IMAGENET1K_V2
class ResNeXt50_32X4D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 25028904,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 77.618,
"acc@5": 93.698,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 25028904,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 81.198,
"acc@5": 95.340,
},
)
DEFAULT = IMAGENET1K_V2
class ResNeXt101_32X8D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 88791336,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 79.312,
"acc@5": 94.526,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
"publication_year": 2016,
"num_params": 88791336,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 82.834,
"acc@5": 96.228,
},
)
DEFAULT = IMAGENET1K_V2
class Wide_ResNet50_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 68883240,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.468,
"acc@5": 94.086,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 68883240,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
"acc@1": 81.602,
"acc@5": 95.758,
},
)
DEFAULT = IMAGENET1K_V2
class Wide_ResNet101_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 126886696,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.848,
"acc@5": 94.284,
},
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "WideResNet",
"publication_year": 2016,
"num_params": 126886696,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
"acc@1": 82.510,
"acc@5": 96.020,
},
)
DEFAULT = IMAGENET1K_V2
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_. `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ResNet18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        weights (ResNet34_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ResNet34_Weights.verify(weights)

    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        weights (ResNet50_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ResNet50_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        weights (ResNet101_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ResNet101_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        weights (ResNet152_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ResNet152_Weights.verify(weights)

    return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
    *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        weights (ResNeXt50_32X4D_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ResNeXt50_32X4D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
    *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        weights (ResNeXt101_32X8D_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = ResNeXt101_32X8D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
    *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
@@ -389,14 +673,19 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        weights (Wide_ResNet50_2_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = Wide_ResNet50_2_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)

@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
    *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
@@ -406,8 +695,10 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs:
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        weights (Wide_ResNet101_2_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = Wide_ResNet101_2_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
from .deeplabv3 import *
from .fcn import *
from .lraspp import *
@@ -4,7 +4,6 @@ from typing import Optional, Dict
from torch import nn, Tensor
from torch.nn import functional as F

from ...utils import _log_api_usage_once
@@ -36,10 +35,3 @@ class _SimpleSegmentationModel(nn.Module):
        result["aux"] = x

        return result
from functools import partial
from typing import Any, List, Optional

import torch
from torch import nn
from torch.nn import functional as F

from ...transforms._presets import SemanticSegmentation, InterpolationMode
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large
from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights
from ._utils import _SimpleSegmentationModel
from .fcn import FCNHead

__all__ = [
    "DeepLabV3",
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]

class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
@@ -114,7 +114,7 @@ class ASPP(nn.Module):

def _deeplabv3_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
@@ -128,8 +128,62 @@ def _deeplabv3_resnet(
    return DeepLabV3(backbone, classifier, aux_classifier)
_COMMON_META = {
"task": "image_semantic_segmentation",
"architecture": "DeepLabV3",
"publication_year": 2017,
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class DeepLabV3_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 42004074,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
"mIoU": 66.4,
"acc": 92.4,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 60996202,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
"mIoU": 67.4,
"acc": 92.4,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 11029328,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
"mIoU": 60.3,
"acc": 91.2,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
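
Because every entry above carries the same meta schema, comparing the segmentation checkpoints becomes scriptable (illustrative sketch, not part of the diff):

    for enum in (DeepLabV3_ResNet50_Weights, DeepLabV3_ResNet101_Weights, DeepLabV3_MobileNet_V3_Large_Weights):
        w = enum.COCO_WITH_VOC_LABELS_V1
        print(w.url, w.meta["num_params"], w.meta["mIoU"])
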

def _deeplabv3_mobilenetv3(
    backbone: MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
@@ -151,91 +205,124 @@ def _deeplabv3_mobilenetv3(
    return DeepLabV3(backbone, classifier, aux_classifier)
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet50(
    *,
    weights: Optional[DeepLabV3_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        weights (DeepLabV3_ResNet50_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
    """
    weights = DeepLabV3_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def deeplabv3_resnet101(
    *,
    weights: Optional[DeepLabV3_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        weights (DeepLabV3_ResNet101_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): The number of classes
        aux_loss (bool, optional): If True, include an auxiliary classifier
        weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone
    """
    weights = DeepLabV3_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
@handle_legacy_interface(
weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def deeplabv3_mobilenet_v3_large(
    *,
    weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> DeepLabV3:
    """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        weights (DeepLabV3_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
    """
    weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
from functools import partial
from typing import Any, Optional

from torch import nn

from ...transforms._presets import SemanticSegmentation, InterpolationMode
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101
from ._utils import _SimpleSegmentationModel

__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"]

class FCN(_SimpleSegmentationModel):
@@ -49,8 +47,47 @@ class FCNHead(nn.Sequential):
        super().__init__(*layers)
_COMMON_META = {
"task": "image_semantic_segmentation",
"architecture": "FCN",
"publication_year": 2014,
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FCN_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 35322218,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
"mIoU": 60.5,
"acc": 91.4,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
class FCN_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
**_COMMON_META,
"num_params": 54314346,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
"mIoU": 63.7,
"acc": 91.9,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1

def _fcn_resnet(
    backbone: ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> FCN:
@@ -64,61 +101,83 @@ def _fcn_resnet(
    return FCN(backbone, classifier, aux_classifier)
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcn_resnet50(
    *,
    weights: Optional[FCN_ResNet50_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> FCN:
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        weights (FCN_ResNet50_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
    """
    weights = FCN_ResNet50_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
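The handle_legacy_interface decorator keeps the old boolean kwargs working by translating them to the defaults declared above; roughly (sketch, deprecation warnings omitted):

# Deprecated spelling, still accepted:
model = fcn_resnet50(pretrained=True)
# is rewritten by the decorator into:
model = fcn_resnet50(weights=FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)

# Same for the backbone flag:
model = fcn_resnet50(pretrained_backbone=False)
# becomes:
model = fcn_resnet50(weights_backbone=None)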
@handle_legacy_interface(
weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
)
def fcn_resnet101(
    *,
    weights: Optional[FCN_ResNet101_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    aux_loss: Optional[bool] = None,
    weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> FCN:
    """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        weights (FCN_ResNet101_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone
    """
    weights = FCN_ResNet101_Weights.verify(weights)
    weights_backbone = ResNet101_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
        aux_loss = _ovewrite_value_param(aux_loss, True)
    elif num_classes is None:
        num_classes = 21

    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional

from torch import nn, Tensor
from torch.nn import functional as F

from ...transforms._presets import SemanticSegmentation, InterpolationMode
from ...utils import _log_api_usage_once
from .._api import WeightsEnum, Weights
from .._meta import _VOC_CATEGORIES
from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large


__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]


class LRASPP(nn.Module):
@@ -30,7 +28,7 @@ class LRASPP(nn.Module):
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int, optional): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
    """
@@ -81,7 +79,7 @@ class LRASPPHead(nn.Module):
        return self.low_classifier(low) + self.high_classifier(x)


def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
@@ -95,31 +93,61 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
    return LRASPP(backbone, low_channels, high_channels, num_classes)
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
transforms=partial(SemanticSegmentation, resize_size=520),
meta={
"task": "image_semantic_segmentation",
"architecture": "LRASPP",
"publication_year": 2019,
"num_params": 3221538,
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
"mIoU": 57.9,
"acc": 91.2,
},
)
DEFAULT = COCO_WITH_VOC_LABELS_V1
@handle_legacy_interface(
weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def lraspp_mobilenet_v3_large(
    *,
    weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        weights (LRASPP_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
    """
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")

    weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 21

    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
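Unlike the DeepLabV3 and FCN builders, LRASPP has no auxiliary head, hence the explicit kwargs guard above; sketch:

try:
    lraspp_mobilenet_v3_large(aux_loss=True)
except NotImplementedError as err:
    print(err)  # This model does not use auxiliary loss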
from functools import partial
from typing import Callable, Any, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "ShuffleNetV2",
    "ShuffleNet_V2_X0_5_Weights",
    "ShuffleNet_V2_X1_0_Weights",
    "ShuffleNet_V2_X1_5_Weights",
    "ShuffleNet_V2_X2_0_Weights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
@@ -159,67 +166,138 @@ class ShuffleNetV2(nn.Module):
        return self._forward_impl(x)


def _shufflenetv2(
    weights: Optional[WeightsEnum],
    progress: bool,
    *args: Any,
    **kwargs: Any,
) -> ShuffleNetV2:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ShuffleNetV2(*args, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
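For reference, _ovewrite_named_param (the misspelling is the actual identifier in the torchvision source) behaves roughly as sketched below: it rejects a user-supplied num_classes that contradicts the checkpoint instead of silently overriding it.

kwargs = {}
_ovewrite_named_param(kwargs, "num_classes", 1000)  # fills in the value from weights.meta

kwargs = {"num_classes": 10}
_ovewrite_named_param(kwargs, "num_classes", 1000)  # raises ValueError: conflicting value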
_COMMON_META = {
"task": "image_classification",
"architecture": "ShuffleNetV2",
"publication_year": 2018,
"size": (224, 224),
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
}
class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 1366792,
"acc@1": 69.362,
"acc@5": 88.316,
},
)
DEFAULT = IMAGENET1K_V1
class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2278604,
"acc@1": 60.552,
"acc@5": 81.746,
},
)
DEFAULT = IMAGENET1K_V1
class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
pass
class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
pass
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
*, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 0.5x output channels, as described in Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_. <https://arxiv.org/abs/1807.11164>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
*, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 1.0x output channels, as described in Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_. <https://arxiv.org/abs/1807.11164>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5(
*, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 1.5x output channels, as described in Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_. <https://arxiv.org/abs/1807.11164>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0(
*, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 2.0x output channels, as described in Constructs a ShuffleNetV2 with 2.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_. <https://arxiv.org/abs/1807.11164>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
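The 1.5x and 2.0x variants ship no checkpoints (their weight enums are empty), so those builders only accept weights=None; sketch:

model = shufflenet_v2_x1_5()  # random init is the only option for now
model = shufflenet_v2_x0_5(weights=ShuffleNet_V2_X0_5_Weights.DEFAULT)  # DEFAULT aliases IMAGENET1K_V1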
from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.init as init

from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]


class Fire(nn.Module):
@@ -97,29 +97,85 @@ class SqueezeNet(nn.Module):
        return torch.flatten(x, 1)


def _squeezenet(
    version: str,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SqueezeNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = SqueezeNet(version, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
_COMMON_META = {
"task": "image_classification",
"architecture": "SqueezeNet",
"publication_year": 2016,
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
}
class SqueezeNet1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"min_size": (21, 21),
"num_params": 1248424,
"acc@1": 58.092,
"acc@5": 80.420,
},
)
DEFAULT = IMAGENET1K_V1
class SqueezeNet1_1_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"min_size": (17, 17),
"num_params": 1235496,
"acc@1": 58.178,
"acc@5": 80.624,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
def squeezenet1_0(
*, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size" accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper. <https://arxiv.org/abs/1602.07360>`_ paper.
The required minimum input size of the model is 21x21. The required minimum input size of the model is 21x21.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (SqueezeNet1_0_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _squeezenet("1_0", pretrained, progress, **kwargs) weights = SqueezeNet1_0_Weights.verify(weights)
return _squeezenet("1_0", weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
def squeezenet1_1(
    *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
    r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
@@ -127,7 +183,8 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any
    The required minimum input size of the model is 17x17.

    Args:
        weights (SqueezeNet1_1_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = SqueezeNet1_1_Weights.verify(weights)

    return _squeezenet("1_1", weights, progress, **kwargs)
from functools import partial
from typing import Union, List, Dict, Any, Optional, cast

import torch
import torch.nn as nn

from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "VGG",
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]
class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
@@ -95,107 +95,276 @@ cfgs: Dict[str, List[Union[str, int]]] = {
}


def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    if weights is not None:
        kwargs["init_weights"] = False
        if weights.meta["categories"] is not None:
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
    return model
_COMMON_META = {
"task": "image_classification",
"architecture": "VGG",
"publication_year": 2014,
"size": (224, 224),
"min_size": (32, 32),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
}
class VGG11_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132863336,
"acc@1": 69.020,
"acc@5": 88.628,
},
)
DEFAULT = IMAGENET1K_V1
class VGG11_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132868840,
"acc@1": 70.370,
"acc@5": 89.810,
},
)
DEFAULT = IMAGENET1K_V1
class VGG13_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133047848,
"acc@1": 69.928,
"acc@5": 89.246,
},
)
DEFAULT = IMAGENET1K_V1
class VGG13_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133053736,
"acc@1": 71.586,
"acc@5": 90.374,
},
)
DEFAULT = IMAGENET1K_V1
class VGG16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138357544,
"acc@1": 71.592,
"acc@5": 90.382,
},
)
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the
# `classifier` module are filled with nans.
IMAGENET1K_FEATURES = Weights(
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial(
ImageClassification,
crop_size=224,
mean=(0.48235, 0.45882, 0.40784),
std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
),
meta={
**_COMMON_META,
"num_params": 138357544,
"categories": None,
"recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
"acc@1": float("nan"),
"acc@5": float("nan"),
},
)
DEFAULT = IMAGENET1K_V1
class VGG16_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138365992,
"acc@1": 73.360,
"acc@5": 91.516,
},
)
DEFAULT = IMAGENET1K_V1
class VGG19_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143667240,
"acc@1": 72.376,
"acc@5": 90.876,
},
)
DEFAULT = IMAGENET1K_V1
class VGG19_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143678248,
"acc@1": 74.218,
"acc@5": 91.842,
},
)
DEFAULT = IMAGENET1K_V1
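A note on consuming IMAGENET1K_FEATURES: since that checkpoint's classifier tensors are NaN-filled, only the convolutional trunk is usable, e.g. as an SSD-style backbone (sketch):

backbone = vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES).features
# The custom mean/std in this entry's transforms appear to reproduce the paper's
# 0-255 input standardization rather than the usual ImageNet preset.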
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") from r"""VGG 11-layer model (configuration "A") from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG11_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg11", "A", False, pretrained, progress, **kwargs) weights = VGG11_Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 11-layer model (configuration "A") with batch normalization r"""VGG 11-layer model (configuration "A") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG11_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs) weights = VGG11_BN_Weights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs)
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B") r"""VGG 13-layer model (configuration "B")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG13_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg13", "B", False, pretrained, progress, **kwargs) weights = VGG13_Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 13-layer model (configuration "B") with batch normalization r"""VGG 13-layer model (configuration "B") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG13_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs) weights = VGG13_BN_Weights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs)
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D") r"""VGG 16-layer model (configuration "D")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG16_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg16", "D", False, pretrained, progress, **kwargs) weights = VGG16_Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 16-layer model (configuration "D") with batch normalization r"""VGG 16-layer model (configuration "D") with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG16_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs) weights = VGG16_BN_Weights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs)
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration "E") r"""VGG 19-layer model (configuration "E")
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG19_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg19", "E", False, pretrained, progress, **kwargs) weights = VGG19_Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
r"""VGG 19-layer model (configuration 'E') with batch normalization r"""VGG 19-layer model (configuration 'E') with batch normalization
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.
The required minimum input size of the model is 32x32. The required minimum input size of the model is 32x32.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (VGG19_BN_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs) weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs)
from functools import partial
from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union

import torch.nn as nn
from torch import Tensor

from ...transforms._presets import VideoClassification, InterpolationMode
from ...utils import _log_api_usage_once
from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "VideoResNet",
    "R3D_18_Weights",
    "MC3_18_Weights",
    "R2Plus1D_18_Weights",
    "r3d_18",
    "mc3_18",
    "r2plus1d_18",
]


class Conv3DSimple(nn.Conv3d):
@@ -281,80 +288,152 @@ class VideoResNet(nn.Module):
        return nn.Sequential(*layers)
def _video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VideoResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = VideoResNet(block, conv_makers, layers, stem, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
_COMMON_META = {
"task": "video_classification",
"publication_year": 2017,
"size": (112, 112),
"min_size": (1, 1),
"categories": _KINETICS400_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
}
class R3D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "R3D",
"num_params": 33371472,
"acc@1": 52.75,
"acc@5": 75.45,
},
)
DEFAULT = KINETICS400_V1
class MC3_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "MC3",
"num_params": 11695440,
"acc@1": 53.90,
"acc@5": 76.29,
},
)
DEFAULT = KINETICS400_V1
class R2Plus1D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "R(2+1)D",
"num_params": 31505325,
"acc@1": 57.50,
"acc@5": 78.81,
},
)
DEFAULT = KINETICS400_V1
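Video checkpoints follow the same pattern; a hedged sketch (the random tensor stands in for a preprocessed clip, already sized by the VideoClassification preset to (N, C, T, H, W); for real video frames, weights.transforms() performs the resize, crop, and normalization):

import torch
from torchvision.models.video import R3D_18_Weights, r3d_18

weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights).eval()

batch = torch.rand(1, 3, 16, 112, 112)  # (N, C, T, H, W)
with torch.no_grad():
    idx = model(batch).softmax(dim=1).argmax(dim=1).item()
print(weights.meta["categories"][idx])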
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Construct 18 layer Resnet3D model as in """Construct 18 layer Resnet3D model as in
https://arxiv.org/abs/1711.11248 https://arxiv.org/abs/1711.11248
Args: Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400 weights (R3D_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
Returns: Returns:
nn.Module: R3D-18 network VideoResNet: R3D-18 network
""" """
weights = R3D_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
"r3d_18", BasicBlock,
pretrained, [Conv3DSimple] * 4,
[2, 2, 2, 2],
BasicStem,
weights,
progress, progress,
block=BasicBlock,
conv_makers=[Conv3DSimple] * 4,
layers=[2, 2, 2, 2],
stem=BasicStem,
**kwargs, **kwargs,
) )
def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Constructor for 18 layer Mixed Convolution network as in """Constructor for 18 layer Mixed Convolution network as in
https://arxiv.org/abs/1711.11248 https://arxiv.org/abs/1711.11248
Args: Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400 weights (MC3_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
Returns: Returns:
nn.Module: MC3 Network definition VideoResNet: MC3 Network definition
""" """
weights = MC3_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
"mc3_18", BasicBlock,
pretrained, [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item]
[2, 2, 2, 2],
BasicStem,
weights,
progress, progress,
block=BasicBlock,
conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item]
layers=[2, 2, 2, 2],
stem=BasicStem,
**kwargs, **kwargs,
) )
def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Constructor for the 18 layer deep R(2+1)D network as in """Constructor for the 18 layer deep R(2+1)D network as in
https://arxiv.org/abs/1711.11248 https://arxiv.org/abs/1711.11248
Args: Args:
pretrained (bool): If True, returns a model pre-trained on Kinetics-400 weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
Returns: Returns:
nn.Module: R(2+1)D-18 network VideoResNet: R(2+1)D-18 network
""" """
weights = R2Plus1D_18_Weights.verify(weights)
return _video_resnet( return _video_resnet(
"r2plus1d_18", BasicBlock,
pretrained, [Conv2Plus1D] * 4,
[2, 2, 2, 2],
R2Plus1dStem,
weights,
progress, progress,
block=BasicBlock,
conv_makers=[Conv2Plus1D] * 4,
layers=[2, 2, 2, 2],
stem=R2Plus1dStem,
**kwargs, **kwargs,
) )
@@ -6,25 +6,26 @@ from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

from ..ops.misc import Conv2dNormActivation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
    "VisionTransformer",
    "ViT_B_16_Weights",
    "ViT_B_32_Weights",
    "ViT_L_16_Weights",
    "ViT_L_32_Weights",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
]


class ConvStemConfig(NamedTuple):
    out_channels: int
@@ -274,18 +275,20 @@ class VisionTransformer(nn.Module):
def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    image_size = kwargs.pop("image_size", 224)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
@@ -296,98 +299,180 @@ def _vision_transformer(
        **kwargs,
    )

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
_COMMON_META = {
"task": "image_classification",
"architecture": "ViT",
"publication_year": 2020,
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class ViT_B_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 86567656,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
"acc@1": 81.072,
"acc@5": 95.318,
},
)
DEFAULT = IMAGENET1K_V1
class ViT_B_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 88224232,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
"acc@1": 75.912,
"acc@5": 92.466,
},
)
DEFAULT = IMAGENET1K_V1
class ViT_L_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=242),
meta={
**_COMMON_META,
"num_params": 304326632,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
"acc@1": 79.662,
"acc@5": 94.638,
},
)
DEFAULT = IMAGENET1K_V1
class ViT_L_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
transforms=partial(ImageClassification, crop_size=224),
meta={
**_COMMON_META,
"num_params": 306535400,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
"acc@1": 76.972,
"acc@5": 93.07,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_b_16 architecture from Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ViT_B_16_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
arch="vit_b_16",
patch_size=16, patch_size=16,
num_layers=12, num_layers=12,
num_heads=12, num_heads=12,
hidden_dim=768, hidden_dim=768,
mlp_dim=3072, mlp_dim=3072,
pretrained=pretrained, weights=weights,
progress=progress, progress=progress,
**kwargs, **kwargs,
) )
def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_b_32 architecture from Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ViT_B_32_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
arch="vit_b_32",
patch_size=32, patch_size=32,
num_layers=12, num_layers=12,
num_heads=12, num_heads=12,
hidden_dim=768, hidden_dim=768,
mlp_dim=3072, mlp_dim=3072,
pretrained=pretrained, weights=weights,
progress=progress, progress=progress,
**kwargs, **kwargs,
) )
def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_l_16 architecture from Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ViT_L_16_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
arch="vit_l_16",
patch_size=16, patch_size=16,
num_layers=24, num_layers=24,
num_heads=16, num_heads=16,
hidden_dim=1024, hidden_dim=1024,
mlp_dim=4096, mlp_dim=4096,
pretrained=pretrained, weights=weights,
progress=progress, progress=progress,
**kwargs, **kwargs,
) )
def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: @handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
""" """
Constructs a vit_l_32 architecture from Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ViT_L_32_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer( return _vision_transformer(
arch="vit_l_32",
patch_size=32, patch_size=32,
num_layers=24, num_layers=24,
num_heads=16, num_heads=16,
hidden_dim=1024, hidden_dim=1024,
mlp_dim=4096, mlp_dim=4096,
pretrained=pretrained, weights=weights,
progress=progress, progress=progress,
**kwargs, **kwargs,
) )
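One caveat specific to ViT: _vision_transformer pops image_size from kwargs, and the published checkpoints were trained at 224px, so custom resolutions only combine with random initialization (sketch):

model = vit_b_16(image_size=384)  # fine: no checkpoint to conflict with
# vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1, image_size=384) would fail in
# load_state_dict because the position-embedding shapes no longer match.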