Unverified Commit 18cf5ab1 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

More Multiweight support cleanups (#4948)

* Updated the link for densenet recipe.

* Set default value of `num_classes` and `num_keypoints` to `None`

* Provide helper methods for parameter checks to reduce duplicate code.

* Throw errors on silent config overwrites from weight meta-data and legacy builders.

* Changing order of arguments + fixing mypy.

* Make the builders fully BC.

* Add "default" weights support that returns always the best weights.
parent 09e759ea
import warnings
from functools import partial
from typing import Any, List, Optional, Union
......@@ -14,6 +13,7 @@ from ....models.quantization.mobilenetv3 import (
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf
......@@ -33,9 +33,9 @@ def _mobilenet_v3_model(
**kwargs: Any,
) -> QuantizableMobileNetV3:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
......@@ -71,6 +71,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
"acc@1": 73.004,
"acc@5": 90.858,
},
default=True,
)
......@@ -80,17 +81,15 @@ def mobilenet_v3_large(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableMobileNetV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
if quantize
else MobileNetV3LargeWeights.ImageNet1K_RefV1
)
else:
weights = None
default_value = (
QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
if quantize
else MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedMobileNetV3LargeWeights.verify(weights)
else:
......
import warnings
from functools import partial
from typing import Any, List, Optional, Type, Union
......@@ -14,6 +13,7 @@ from ....models.quantization.resnet import (
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
......@@ -37,9 +37,9 @@ def _resnet(
**kwargs: Any,
) -> QuantizableResNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableResNet(block, layers, **kwargs)
......@@ -73,6 +73,7 @@ class QuantizedResNet18Weights(Weights):
"acc@1": 69.494,
"acc@5": 88.882,
},
default=True,
)
......@@ -86,6 +87,7 @@ class QuantizedResNet50Weights(Weights):
"acc@1": 75.920,
"acc@5": 92.814,
},
default=False,
)
ImageNet1K_FBGEMM_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
......@@ -96,6 +98,7 @@ class QuantizedResNet50Weights(Weights):
"acc@1": 80.282,
"acc@5": 94.976,
},
default=True,
)
......@@ -109,6 +112,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
"acc@1": 78.986,
"acc@5": 94.480,
},
default=False,
)
ImageNet1K_FBGEMM_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
......@@ -119,6 +123,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
"acc@1": 82.574,
"acc@5": 96.132,
},
default=True,
)
......@@ -128,13 +133,13 @@ def resnet18(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
else:
weights = None
default_value = (
QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedResNet18Weights.verify(weights)
else:
......@@ -149,13 +154,13 @@ def resnet50(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
else:
weights = None
default_value = (
QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedResNet50Weights.verify(weights)
else:
......@@ -170,22 +175,20 @@ def resnext101_32x8d(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
if quantize
else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
)
else:
weights = None
default_value = (
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
if quantize
else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedResNeXt101_32x8dWeights.verify(weights)
else:
weights = ResNeXt101_32x8dWeights.verify(weights)
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
import warnings
from functools import partial
from typing import Any, List, Optional, Union
......@@ -12,6 +11,7 @@ from ....models.quantization.shufflenetv2 import (
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights
......@@ -33,9 +33,9 @@ def _shufflenetv2(
**kwargs: Any,
) -> QuantizableShuffleNetV2:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
......@@ -69,6 +69,7 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
"acc@1": 57.972,
"acc@5": 79.780,
},
default=True,
)
......@@ -82,6 +83,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
"acc@1": 68.360,
"acc@5": 87.582,
},
default=True,
)
......@@ -91,17 +93,15 @@ def shufflenet_v2_x0_5(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
if quantize
else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
)
else:
weights = None
default_value = (
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
if quantize
else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights)
else:
......@@ -116,17 +116,15 @@ def shufflenet_v2_x1_0(
quantize: bool = False,
**kwargs: Any,
) -> QuantizableShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
if quantize
else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
)
else:
weights = None
default_value = (
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
if quantize
else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
)
weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment]
if quantize:
weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights)
else:
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -9,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -53,7 +53,7 @@ def _regnet(
**kwargs: Any,
) -> RegNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
......@@ -74,6 +74,7 @@ class RegNet_y_400mfWeights(Weights):
"acc@1": 74.046,
"acc@5": 91.716,
},
default=True,
)
......@@ -87,6 +88,7 @@ class RegNet_y_800mfWeights(Weights):
"acc@1": 76.420,
"acc@5": 93.136,
},
default=True,
)
......@@ -100,6 +102,7 @@ class RegNet_y_1_6gfWeights(Weights):
"acc@1": 77.950,
"acc@5": 93.966,
},
default=True,
)
......@@ -113,6 +116,7 @@ class RegNet_y_3_2gfWeights(Weights):
"acc@1": 78.948,
"acc@5": 94.576,
},
default=True,
)
......@@ -126,6 +130,7 @@ class RegNet_y_8gfWeights(Weights):
"acc@1": 80.032,
"acc@5": 95.048,
},
default=True,
)
......@@ -139,6 +144,7 @@ class RegNet_y_16gfWeights(Weights):
"acc@1": 80.424,
"acc@5": 95.240,
},
default=True,
)
......@@ -152,6 +158,7 @@ class RegNet_y_32gfWeights(Weights):
"acc@1": 80.878,
"acc@5": 95.340,
},
default=True,
)
......@@ -165,6 +172,7 @@ class RegNet_x_400mfWeights(Weights):
"acc@1": 72.834,
"acc@5": 90.950,
},
default=True,
)
......@@ -178,6 +186,7 @@ class RegNet_x_800mfWeights(Weights):
"acc@1": 75.212,
"acc@5": 92.348,
},
default=True,
)
......@@ -191,6 +200,7 @@ class RegNet_x_1_6gfWeights(Weights):
"acc@1": 77.040,
"acc@5": 93.440,
},
default=True,
)
......@@ -204,6 +214,7 @@ class RegNet_x_3_2gfWeights(Weights):
"acc@1": 78.364,
"acc@5": 93.992,
},
default=True,
)
......@@ -217,6 +228,7 @@ class RegNet_x_8gfWeights(Weights):
"acc@1": 79.344,
"acc@5": 94.686,
},
default=True,
)
......@@ -230,6 +242,7 @@ class RegNet_x_16gfWeights(Weights):
"acc@1": 80.058,
"acc@5": 94.944,
},
default=True,
)
......@@ -243,13 +256,15 @@ class RegNet_x_32gfWeights(Weights):
"acc@1": 80.622,
"acc@5": 95.248,
},
default=True,
)
def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_RefV1)
weights = RegNet_y_400mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
......@@ -257,9 +272,10 @@ def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bo
def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_RefV1)
weights = RegNet_y_800mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
......@@ -267,9 +283,10 @@ def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bo
def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_1_6gfWeights.verify(weights)
params = BlockParams.from_init_params(
......@@ -279,10 +296,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo
def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_3_2gfWeights.verify(weights)
params = BlockParams.from_init_params(
depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
)
......@@ -290,10 +309,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo
def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_8gfWeights.verify(weights)
params = BlockParams.from_init_params(
depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
)
......@@ -301,10 +322,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool =
def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_16gfWeights.verify(weights)
params = BlockParams.from_init_params(
depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
)
......@@ -312,10 +335,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool
def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_RefV1)
weights = RegNet_y_32gfWeights.verify(weights)
params = BlockParams.from_init_params(
depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
)
......@@ -323,70 +348,77 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool
def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_RefV1)
weights = RegNet_x_400mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_RefV1)
weights = RegNet_x_800mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_1_6gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_3_2gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_8gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_16gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
return _regnet(params, weights, progress, **kwargs)
def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_RefV1)
weights = RegNet_x_32gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
return _regnet(params, weights, progress, **kwargs)
import warnings
from functools import partial
from typing import Any, List, Optional, Type, Union
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -41,7 +41,7 @@ def _resnet(
**kwargs: Any,
) -> ResNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ResNet(block, layers, **kwargs)
......@@ -64,6 +64,7 @@ class ResNet18Weights(Weights):
"acc@1": 69.758,
"acc@5": 89.078,
},
default=True,
)
......@@ -77,6 +78,7 @@ class ResNet34Weights(Weights):
"acc@1": 73.314,
"acc@5": 91.420,
},
default=True,
)
......@@ -90,6 +92,7 @@ class ResNet50Weights(Weights):
"acc@1": 76.130,
"acc@5": 92.862,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
......@@ -100,6 +103,7 @@ class ResNet50Weights(Weights):
"acc@1": 80.674,
"acc@5": 95.166,
},
default=True,
)
......@@ -113,6 +117,7 @@ class ResNet101Weights(Weights):
"acc@1": 77.374,
"acc@5": 93.546,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
......@@ -123,6 +128,7 @@ class ResNet101Weights(Weights):
"acc@1": 81.886,
"acc@5": 95.780,
},
default=True,
)
......@@ -136,6 +142,7 @@ class ResNet152Weights(Weights):
"acc@1": 78.312,
"acc@5": 94.046,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
......@@ -146,6 +153,7 @@ class ResNet152Weights(Weights):
"acc@1": 82.284,
"acc@5": 96.002,
},
default=True,
)
......@@ -159,6 +167,7 @@ class ResNeXt50_32x4dWeights(Weights):
"acc@1": 77.618,
"acc@5": 93.698,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
......@@ -169,6 +178,7 @@ class ResNeXt50_32x4dWeights(Weights):
"acc@1": 81.198,
"acc@5": 95.340,
},
default=True,
)
......@@ -182,6 +192,7 @@ class ResNeXt101_32x8dWeights(Weights):
"acc@1": 79.312,
"acc@5": 94.526,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
......@@ -192,6 +203,7 @@ class ResNeXt101_32x8dWeights(Weights):
"acc@1": 82.834,
"acc@5": 96.228,
},
default=True,
)
......@@ -205,6 +217,7 @@ class WideResNet50_2Weights(Weights):
"acc@1": 78.468,
"acc@5": 94.086,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
......@@ -215,6 +228,7 @@ class WideResNet50_2Weights(Weights):
"acc@1": 81.602,
"acc@5": 95.758,
},
default=True,
)
......@@ -228,6 +242,7 @@ class WideResNet101_2Weights(Weights):
"acc@1": 78.848,
"acc@5": 94.284,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
......@@ -238,95 +253,101 @@ class WideResNet101_2Weights(Weights):
"acc@1": 82.510,
"acc@5": 96.020,
},
default=True,
)
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_RefV1)
weights = ResNet18Weights.verify(weights)
return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_RefV1)
weights = ResNet34Weights.verify(weights)
return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_RefV1)
weights = ResNet50Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_RefV1)
weights = ResNet101Weights.verify(weights)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_RefV1)
weights = ResNet152Weights.verify(weights)
return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1)
weights = ResNeXt50_32x4dWeights.verify(weights)
kwargs["groups"] = 32
kwargs["width_per_group"] = 4
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 4)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1)
weights = ResNeXt101_32x8dWeights.verify(weights)
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
_ovewrite_named_param(kwargs, "groups", 32)
_ovewrite_named_param(kwargs, "width_per_group", 8)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = WideResNet50_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community)
weights = WideResNet50_2Weights.verify(weights)
kwargs["width_per_group"] = 64 * 2
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = WideResNet101_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community)
weights = WideResNet101_2Weights.verify(weights)
kwargs["width_per_group"] = 64 * 2
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import resnet50, resnet101
from ..resnet import ResNet50Weights, ResNet101Weights
......@@ -40,6 +40,7 @@ class DeepLabV3ResNet50Weights(Weights):
"mIoU": 66.4,
"acc": 92.4,
},
default=True,
)
......@@ -53,6 +54,7 @@ class DeepLabV3ResNet101Weights(Weights):
"mIoU": 67.4,
"acc": 92.4,
},
default=True,
)
......@@ -66,31 +68,37 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
"mIoU": 60.3,
"acc": 91.2,
},
default=True,
)
def deeplabv3_resnet50(
weights: Optional[DeepLabV3ResNet50Weights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 21,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1)
weights = DeepLabV3ResNet50Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
aux_loss = True
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
......@@ -103,26 +111,31 @@ def deeplabv3_resnet50(
def deeplabv3_resnet101(
weights: Optional[DeepLabV3ResNet101Weights] = None,
weights_backbone: Optional[ResNet101Weights] = None,
progress: bool = True,
num_classes: int = 21,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1)
weights = DeepLabV3ResNet101Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet101Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
aux_loss = True
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
......@@ -135,26 +148,33 @@ def deeplabv3_resnet101(
def deeplabv3_mobilenet_v3_large(
weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 21,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
**kwargs: Any,
) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(
kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1
)
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
aux_loss = True
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
......@@ -30,6 +30,7 @@ class FCNResNet50Weights(Weights):
"mIoU": 60.5,
"acc": 91.4,
},
default=True,
)
......@@ -43,30 +44,37 @@ class FCNResNet101Weights(Weights):
"mIoU": 63.7,
"acc": 91.9,
},
default=True,
)
def fcn_resnet50(
weights: Optional[FCNResNet50Weights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True,
num_classes: int = 21,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None,
**kwargs: Any,
) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FCNResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_RefV1)
weights = FCNResNet50Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None:
aux_loss = True
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
......@@ -79,25 +87,31 @@ def fcn_resnet50(
def fcn_resnet101(
weights: Optional[FCNResNet101Weights] = None,
weights_backbone: Optional[ResNet101Weights] = None,
progress: bool = True,
num_classes: int = 21,
num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None,
**kwargs: Any,
) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FCNResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_RefV1)
weights = FCNResNet101Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet101Weights.verify(weights_backbone)
if weights is not None:
aux_loss = True
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
......@@ -25,31 +25,40 @@ class LRASPPMobileNetV3LargeWeights(Weights):
"mIoU": 57.9,
"acc": 91.2,
},
default=True,
)
def lraspp_mobilenet_v3_large(
weights: Optional[LRASPPMobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 21,
num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
**kwargs: Any,
) -> LRASPP:
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(
kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1
)
weights = LRASPPMobileNetV3LargeWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _lraspp_mobilenetv3(backbone, num_classes)
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -30,7 +30,7 @@ def _shufflenetv2(
**kwargs: Any,
) -> ShuffleNetV2:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ShuffleNetV2(*args, **kwargs)
......@@ -57,6 +57,7 @@ class ShuffleNetV2_x0_5Weights(Weights):
"acc@1": 69.362,
"acc@5": 88.316,
},
default=True,
)
......@@ -69,6 +70,7 @@ class ShuffleNetV2_x1_0Weights(Weights):
"acc@1": 60.552,
"acc@5": 81.746,
},
default=True,
)
......@@ -83,9 +85,10 @@ class ShuffleNetV2_x2_0Weights(Weights):
def shufflenet_v2_x0_5(
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ShuffleNetV2_x0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_Community)
weights = ShuffleNetV2_x0_5Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
......@@ -94,9 +97,10 @@ def shufflenet_v2_x0_5(
def shufflenet_v2_x1_0(
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ShuffleNetV2_x1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_Community)
weights = ShuffleNetV2_x1_0Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
......@@ -105,10 +109,10 @@ def shufflenet_v2_x1_0(
def shufflenet_v2_x1_5(
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x1_5")
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNetV2_x1_5Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
......@@ -117,10 +121,10 @@ def shufflenet_v2_x1_5(
def shufflenet_v2_x2_0(
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x2_0")
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = ShuffleNetV2_x2_0Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
......@@ -30,6 +30,7 @@ class SqueezeNet1_0Weights(Weights):
"acc@1": 58.092,
"acc@5": 80.420,
},
default=True,
)
......@@ -42,16 +43,19 @@ class SqueezeNet1_1Weights(Weights):
"acc@1": 58.178,
"acc@5": 80.624,
},
default=True,
)
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_Community)
weights = SqueezeNet1_0Weights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SqueezeNet("1_0", **kwargs)
......@@ -62,12 +66,14 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_Community)
weights = SqueezeNet1_1Weights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SqueezeNet("1_1", **kwargs)
......
import warnings
from functools import partial
from typing import Any, Optional
......@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -33,7 +33,7 @@ __all__ = [
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
......@@ -57,6 +57,7 @@ class VGG11Weights(Weights):
"acc@1": 69.020,
"acc@5": 88.628,
},
default=True,
)
......@@ -69,6 +70,7 @@ class VGG11BNWeights(Weights):
"acc@1": 70.370,
"acc@5": 89.810,
},
default=True,
)
......@@ -81,6 +83,7 @@ class VGG13Weights(Weights):
"acc@1": 69.928,
"acc@5": 89.246,
},
default=True,
)
......@@ -93,6 +96,7 @@ class VGG13BNWeights(Weights):
"acc@1": 71.586,
"acc@5": 90.374,
},
default=True,
)
......@@ -105,6 +109,7 @@ class VGG16Weights(Weights):
"acc@1": 71.592,
"acc@5": 90.382,
},
default=True,
)
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the
......@@ -122,6 +127,7 @@ class VGG16Weights(Weights):
"acc@1": float("nan"),
"acc@5": float("nan"),
},
default=False,
)
......@@ -134,6 +140,7 @@ class VGG16BNWeights(Weights):
"acc@1": 73.360,
"acc@5": 91.516,
},
default=True,
)
......@@ -146,6 +153,7 @@ class VGG19Weights(Weights):
"acc@1": 72.376,
"acc@5": 90.876,
},
default=True,
)
......@@ -158,76 +166,85 @@ class VGG19BNWeights(Weights):
"acc@1": 74.218,
"acc@5": 91.842,
},
default=True,
)
def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_RefV1)
weights = VGG11Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_RefV1)
weights = VGG11BNWeights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_RefV1)
weights = VGG13Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_RefV1)
weights = VGG13BNWeights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_RefV1)
weights = VGG16Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_RefV1)
weights = VGG16BNWeights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_RefV1)
weights = VGG19Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_RefV1)
weights = VGG19BNWeights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs)
import warnings
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union
......@@ -18,6 +17,7 @@ from ....models.video.resnet import (
)
from .._api import Weights, WeightEntry
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [
......@@ -41,7 +41,7 @@ def _video_resnet(
**kwargs: Any,
) -> VideoResNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
......@@ -68,6 +68,7 @@ class R3D_18Weights(Weights):
"acc@1": 52.75,
"acc@5": 75.45,
},
default=True,
)
......@@ -80,6 +81,7 @@ class MC3_18Weights(Weights):
"acc@1": 53.90,
"acc@5": 76.29,
},
default=True,
)
......@@ -92,13 +94,15 @@ class R2Plus1D_18Weights(Weights):
"acc@1": 57.50,
"acc@5": 78.81,
},
default=True,
)
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = R3D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_RefV1)
weights = R3D_18Weights.verify(weights)
return _video_resnet(
......@@ -113,9 +117,10 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MC3_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_RefV1)
weights = MC3_18Weights.verify(weights)
return _video_resnet(
......@@ -130,9 +135,10 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = R2Plus1D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_RefV1)
weights = R2Plus1D_18Weights.verify(weights)
return _video_resnet(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment