Unverified commit 18cf5ab1 authored by Vasilis Vryniotis, committed by GitHub

More Multiweight support cleanups (#4948)

* Updated the link for the densenet recipe.

* Set default value of `num_classes` and `num_keypoints` to `None`

* Provide helper methods for parameter checks to reduce duplicate code.

* Throw errors instead of silently overwriting config values from weight meta-data and legacy builders.

* Change the order of arguments and fix mypy errors.

* Make the builders fully backwards compatible (BC).

* Add "default" weights support that returns always the best weights.
parent 09e759ea
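The `_utils` helpers imported throughout this diff (`_ovewrite_named_param`, `_ovewrite_value_param`, `_deprecated_param`, `_deprecated_positional`) are not defined in it. A minimal sketch of what the overwrite checks could look like, inferred only from the call sites below; the actual implementations in the prototype `_utils.py` may differ in signatures and error messages:

```python
from typing import Any, Dict, Optional, TypeVar

V = TypeVar("V")


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    # If the caller already passed the kwarg, it must agree with the value implied
    # by the weights meta-data; otherwise raise instead of silently overwriting it.
    if param in kwargs and kwargs[param] != new_value:
        raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    kwargs[param] = new_value


def _ovewrite_value_param(current: Optional[V], new_value: V) -> V:
    # Same check for explicit parameters such as num_classes / aux_loss.
    if current is not None and current != new_value:
        raise ValueError(f"Expected value {new_value} but got {current} instead.")
    return new_value
```

This is what "throw errors on silent config overwrites" amounts to: a user-supplied value that contradicts the weight meta-data fails loudly, while a missing or matching value is filled in.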
-import warnings
 from functools import partial
 from typing import Any, List, Optional, Union
@@ -14,6 +13,7 @@ from ....models.quantization.mobilenetv3 import (
 )
 from .._api import Weights, WeightEntry
 from .._meta import _IMAGENET_CATEGORIES
+from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
 from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf
@@ -33,9 +33,9 @@ def _mobilenet_v3_model(
     **kwargs: Any,
 ) -> QuantizableMobileNetV3:
     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
         if "backend" in weights.meta:
-            kwargs["backend"] = weights.meta["backend"]
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "qnnpack")
     model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
@@ -71,6 +71,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
             "acc@1": 73.004,
             "acc@5": 90.858,
         },
+        default=True,
     )
@@ -80,17 +81,15 @@ def mobilenet_v3_large(
     quantize: bool = False,
     **kwargs: Any,
 ) -> QuantizableMobileNetV3:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        if kwargs.pop("pretrained"):
-            weights = (
-                QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
-                if quantize
-                else MobileNetV3LargeWeights.ImageNet1K_RefV1
-            )
-        else:
-            weights = None
+        default_value = (
+            QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
+            if quantize
+            else MobileNetV3LargeWeights.ImageNet1K_RefV1
+        )
+        weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)  # type: ignore[assignment]
     if quantize:
         weights = QuantizedMobileNetV3LargeWeights.verify(weights)
     else:
...
-import warnings
 from functools import partial
 from typing import Any, List, Optional, Type, Union
@@ -14,6 +13,7 @@ from ....models.quantization.resnet import (
 )
 from .._api import Weights, WeightEntry
 from .._meta import _IMAGENET_CATEGORIES
+from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
 from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
@@ -37,9 +37,9 @@ def _resnet(
     **kwargs: Any,
 ) -> QuantizableResNet:
     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
         if "backend" in weights.meta:
-            kwargs["backend"] = weights.meta["backend"]
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
     model = QuantizableResNet(block, layers, **kwargs)
@@ -73,6 +73,7 @@ class QuantizedResNet18Weights(Weights):
             "acc@1": 69.494,
             "acc@5": 88.882,
         },
+        default=True,
     )
@@ -86,6 +87,7 @@ class QuantizedResNet50Weights(Weights):
             "acc@1": 75.920,
             "acc@5": 92.814,
         },
+        default=False,
     )
     ImageNet1K_FBGEMM_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
@@ -96,6 +98,7 @@ class QuantizedResNet50Weights(Weights):
             "acc@1": 80.282,
             "acc@5": 94.976,
         },
+        default=True,
     )
@@ -109,6 +112,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
             "acc@1": 78.986,
             "acc@5": 94.480,
         },
+        default=False,
     )
     ImageNet1K_FBGEMM_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
@@ -119,6 +123,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
             "acc@1": 82.574,
             "acc@5": 96.132,
         },
+        default=True,
     )
@@ -128,13 +133,13 @@ def resnet18(
     quantize: bool = False,
     **kwargs: Any,
 ) -> QuantizableResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        if kwargs.pop("pretrained"):
-            weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
-        else:
-            weights = None
+        default_value = (
+            QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
+        )
+        weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)  # type: ignore[assignment]
     if quantize:
         weights = QuantizedResNet18Weights.verify(weights)
     else:
@@ -149,13 +154,13 @@ def resnet50(
     quantize: bool = False,
     **kwargs: Any,
 ) -> QuantizableResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        if kwargs.pop("pretrained"):
-            weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
-        else:
-            weights = None
+        default_value = (
+            QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
+        )
+        weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)  # type: ignore[assignment]
     if quantize:
         weights = QuantizedResNet50Weights.verify(weights)
     else:
@@ -170,22 +175,20 @@ def resnext101_32x8d(
     quantize: bool = False,
     **kwargs: Any,
 ) -> QuantizableResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        if kwargs.pop("pretrained"):
-            weights = (
-                QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
-                if quantize
-                else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
-            )
-        else:
-            weights = None
+        default_value = (
+            QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
+            if quantize
+            else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
+        )
+        weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)  # type: ignore[assignment]
     if quantize:
         weights = QuantizedResNeXt101_32x8dWeights.verify(weights)
     else:
         weights = ResNeXt101_32x8dWeights.verify(weights)
-    kwargs["groups"] = 32
-    kwargs["width_per_group"] = 8
+    _ovewrite_named_param(kwargs, "groups", 32)
+    _ovewrite_named_param(kwargs, "width_per_group", 8)
     return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
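The new `default=True` / `default=False` flags on the `WeightEntry` definitions are what backs the commit's "default" weights support: one entry per enum is flagged as the default, and that is the one a generic "best available" lookup should return. The resolution logic lives in the prototype `_api.py` and is not part of this diff; a purely illustrative sketch of the idea, using simplified stand-in classes rather than the actual API:

```python
from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)
class WeightEntry:  # simplified stand-in for the prototype WeightEntry
    url: str
    default: bool = False


class QuantizedResNet50Weights(Enum):  # simplified stand-in; the RefV2 values appear in the diff above
    ImageNet1K_FBGEMM_RefV1 = WeightEntry(url="refv1-url-not-shown-in-this-excerpt", default=False)
    ImageNet1K_FBGEMM_RefV2 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", default=True
    )

    @classmethod
    def default_weights(cls) -> "QuantizedResNet50Weights":
        # "default" resolves to the single entry flagged default=True,
        # i.e. the best weights currently available for the model.
        return next(w for w in cls if w.value.default)


assert QuantizedResNet50Weights.default_weights() is QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV2
```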
-import warnings
 from functools import partial
 from typing import Any, List, Optional, Union
@@ -12,6 +11,7 @@ from ....models.quantization.shufflenetv2 import (
 )
 from .._api import Weights, WeightEntry
 from .._meta import _IMAGENET_CATEGORIES
+from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
 from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights
@@ -33,9 +33,9 @@ def _shufflenetv2(
     **kwargs: Any,
 ) -> QuantizableShuffleNetV2:
     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
         if "backend" in weights.meta:
-            kwargs["backend"] = weights.meta["backend"]
+            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
     model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
@@ -69,6 +69,7 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
             "acc@1": 57.972,
             "acc@5": 79.780,
         },
+        default=True,
     )
@@ -82,6 +83,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
             "acc@1": 68.360,
             "acc@5": 87.582,
         },
+        default=True,
     )
@@ -91,17 +93,15 @@ def shufflenet_v2_x0_5(
     quantize: bool = False,
     **kwargs: Any,
 ) -> QuantizableShuffleNetV2:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        if kwargs.pop("pretrained"):
-            weights = (
-                QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
-                if quantize
-                else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
-            )
-        else:
-            weights = None
+        default_value = (
+            QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
+            if quantize
+            else ShuffleNetV2_x0_5Weights.ImageNet1K_Community
+        )
+        weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)  # type: ignore[assignment]
     if quantize:
         weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights)
     else:
@@ -116,17 +116,15 @@ def shufflenet_v2_x1_0(
     quantize: bool = False,
     **kwargs: Any,
 ) -> QuantizableShuffleNetV2:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        if kwargs.pop("pretrained"):
-            weights = (
-                QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
-                if quantize
-                else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
-            )
-        else:
-            weights = None
+        default_value = (
+            QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
+            if quantize
+            else ShuffleNetV2_x1_0Weights.ImageNet1K_Community
+        )
+        weights = _deprecated_param(kwargs, "pretrained", "weights", default_value)  # type: ignore[assignment]
     if quantize:
         weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights)
     else:
...
-import warnings
 from functools import partial
 from typing import Any, Optional
@@ -9,6 +8,7 @@ from torchvision.transforms.functional import InterpolationMode
 from ...models.regnet import RegNet, BlockParams
 from ._api import Weights, WeightEntry
 from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
 __all__ = [
@@ -53,7 +53,7 @@ def _regnet(
     **kwargs: Any,
 ) -> RegNet:
     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
     norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
     model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
@@ -74,6 +74,7 @@ class RegNet_y_400mfWeights(Weights):
             "acc@1": 74.046,
             "acc@5": 91.716,
         },
+        default=True,
     )
@@ -87,6 +88,7 @@ class RegNet_y_800mfWeights(Weights):
             "acc@1": 76.420,
             "acc@5": 93.136,
         },
+        default=True,
     )
@@ -100,6 +102,7 @@ class RegNet_y_1_6gfWeights(Weights):
             "acc@1": 77.950,
             "acc@5": 93.966,
         },
+        default=True,
     )
@@ -113,6 +116,7 @@ class RegNet_y_3_2gfWeights(Weights):
             "acc@1": 78.948,
             "acc@5": 94.576,
         },
+        default=True,
     )
@@ -126,6 +130,7 @@ class RegNet_y_8gfWeights(Weights):
             "acc@1": 80.032,
             "acc@5": 95.048,
         },
+        default=True,
     )
@@ -139,6 +144,7 @@ class RegNet_y_16gfWeights(Weights):
             "acc@1": 80.424,
             "acc@5": 95.240,
         },
+        default=True,
     )
@@ -152,6 +158,7 @@ class RegNet_y_32gfWeights(Weights):
             "acc@1": 80.878,
             "acc@5": 95.340,
         },
+        default=True,
     )
@@ -165,6 +172,7 @@ class RegNet_x_400mfWeights(Weights):
             "acc@1": 72.834,
             "acc@5": 90.950,
         },
+        default=True,
     )
@@ -178,6 +186,7 @@ class RegNet_x_800mfWeights(Weights):
             "acc@1": 75.212,
             "acc@5": 92.348,
         },
+        default=True,
     )
@@ -191,6 +200,7 @@ class RegNet_x_1_6gfWeights(Weights):
             "acc@1": 77.040,
             "acc@5": 93.440,
         },
+        default=True,
     )
@@ -204,6 +214,7 @@ class RegNet_x_3_2gfWeights(Weights):
             "acc@1": 78.364,
             "acc@5": 93.992,
         },
+        default=True,
     )
@@ -217,6 +228,7 @@ class RegNet_x_8gfWeights(Weights):
             "acc@1": 79.344,
             "acc@5": 94.686,
         },
+        default=True,
     )
@@ -230,6 +242,7 @@ class RegNet_x_16gfWeights(Weights):
             "acc@1": 80.058,
             "acc@5": 94.944,
         },
+        default=True,
     )
@@ -243,13 +256,15 @@ class RegNet_x_32gfWeights(Weights):
             "acc@1": 80.622,
             "acc@5": 95.248,
         },
+        default=True,
     )
 def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_400mfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
@@ -257,9 +272,10 @@ def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bo
 def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_800mfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
@@ -267,9 +283,10 @@ def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bo
 def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_1_6gfWeights.verify(weights)
     params = BlockParams.from_init_params(
@@ -279,10 +296,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo
 def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_3_2gfWeights.verify(weights)
     params = BlockParams.from_init_params(
         depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
     )
@@ -290,10 +309,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo
 def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_8gfWeights.verify(weights)
     params = BlockParams.from_init_params(
         depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
     )
@@ -301,10 +322,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool =
 def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_16gfWeights.verify(weights)
     params = BlockParams.from_init_params(
         depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
     )
@@ -312,10 +335,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool
 def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_y_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_RefV1)
     weights = RegNet_y_32gfWeights.verify(weights)
     params = BlockParams.from_init_params(
         depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
     )
@@ -323,70 +348,77 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool
 def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_400mfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
 def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_800mfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
 def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_1_6gfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
 def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_3_2gfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
 def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_8gfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
 def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_16gfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
 def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = RegNet_x_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_RefV1)
     weights = RegNet_x_32gfWeights.verify(weights)
     params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
-import warnings
 from functools import partial
 from typing import Any, List, Optional, Type, Union
@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
 from ...models.resnet import BasicBlock, Bottleneck, ResNet
 from ._api import Weights, WeightEntry
 from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
 __all__ = [
@@ -41,7 +41,7 @@ def _resnet(
     **kwargs: Any,
 ) -> ResNet:
     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
     model = ResNet(block, layers, **kwargs)
@@ -64,6 +64,7 @@ class ResNet18Weights(Weights):
             "acc@1": 69.758,
             "acc@5": 89.078,
         },
+        default=True,
     )
@@ -77,6 +78,7 @@ class ResNet34Weights(Weights):
             "acc@1": 73.314,
             "acc@5": 91.420,
         },
+        default=True,
     )
@@ -90,6 +92,7 @@ class ResNet50Weights(Weights):
             "acc@1": 76.130,
             "acc@5": 92.862,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
@@ -100,6 +103,7 @@ class ResNet50Weights(Weights):
             "acc@1": 80.674,
             "acc@5": 95.166,
         },
+        default=True,
     )
@@ -113,6 +117,7 @@ class ResNet101Weights(Weights):
             "acc@1": 77.374,
             "acc@5": 93.546,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
@@ -123,6 +128,7 @@ class ResNet101Weights(Weights):
             "acc@1": 81.886,
             "acc@5": 95.780,
         },
+        default=True,
     )
@@ -136,6 +142,7 @@ class ResNet152Weights(Weights):
             "acc@1": 78.312,
             "acc@5": 94.046,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
@@ -146,6 +153,7 @@ class ResNet152Weights(Weights):
             "acc@1": 82.284,
            "acc@5": 96.002,
         },
+        default=True,
     )
@@ -159,6 +167,7 @@ class ResNeXt50_32x4dWeights(Weights):
             "acc@1": 77.618,
             "acc@5": 93.698,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
@@ -169,6 +178,7 @@ class ResNeXt50_32x4dWeights(Weights):
             "acc@1": 81.198,
             "acc@5": 95.340,
         },
+        default=True,
     )
@@ -182,6 +192,7 @@ class ResNeXt101_32x8dWeights(Weights):
             "acc@1": 79.312,
             "acc@5": 94.526,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
@@ -192,6 +203,7 @@ class ResNeXt101_32x8dWeights(Weights):
             "acc@1": 82.834,
             "acc@5": 96.228,
         },
+        default=True,
     )
@@ -205,6 +217,7 @@ class WideResNet50_2Weights(Weights):
             "acc@1": 78.468,
             "acc@5": 94.086,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
@@ -215,6 +228,7 @@ class WideResNet50_2Weights(Weights):
             "acc@1": 81.602,
             "acc@5": 95.758,
         },
+        default=True,
     )
@@ -228,6 +242,7 @@ class WideResNet101_2Weights(Weights):
             "acc@1": 78.848,
             "acc@5": 94.284,
         },
+        default=False,
     )
     ImageNet1K_RefV2 = WeightEntry(
         url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
@@ -238,95 +253,101 @@ class WideResNet101_2Weights(Weights):
             "acc@1": 82.510,
             "acc@5": 96.020,
         },
+        default=True,
     )
 def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_RefV1)
     weights = ResNet18Weights.verify(weights)
     return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
 def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_RefV1)
     weights = ResNet34Weights.verify(weights)
     return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
 def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_RefV1)
     weights = ResNet50Weights.verify(weights)
     return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
 def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_RefV1)
    weights = ResNet101Weights.verify(weights)
     return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
 def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_RefV1)
     weights = ResNet152Weights.verify(weights)
     return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
 def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1)
     weights = ResNeXt50_32x4dWeights.verify(weights)
-    kwargs["groups"] = 32
-    kwargs["width_per_group"] = 4
+    _ovewrite_named_param(kwargs, "groups", 32)
+    _ovewrite_named_param(kwargs, "width_per_group", 4)
     return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
 def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1)
     weights = ResNeXt101_32x8dWeights.verify(weights)
-    kwargs["groups"] = 32
-    kwargs["width_per_group"] = 8
+    _ovewrite_named_param(kwargs, "groups", 32)
+    _ovewrite_named_param(kwargs, "width_per_group", 8)
     return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
 def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = WideResNet50_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community)
     weights = WideResNet50_2Weights.verify(weights)
-    kwargs["width_per_group"] = 64 * 2
+    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
     return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
 def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = WideResNet101_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community)
     weights = WideResNet101_2Weights.verify(weights)
-    kwargs["width_per_group"] = 64 * 2
+    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
     return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
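Taken together with the deprecation helpers, the classification builders above accept either the new `weights` enum values or the legacy `pretrained` flag. A rough usage sketch, assuming the prototype namespace `torchvision.prototype.models` these files live under (the exact warnings come from the `_deprecated_*` helpers, which are not shown in this diff):

```python
from torchvision.prototype import models as PM

# New multi-weight API: request an explicit weight entry.
model = PM.resnet50(weights=PM.ResNet50Weights.ImageNet1K_RefV2)

# Legacy keyword still works ("fully BC") but is routed through _deprecated_param,
# which warns and maps pretrained=True to the historical default entry (RefV1).
model = PM.resnet50(pretrained=True)

# A positional boolean (the old `resnet50(True)` style) is caught by the
# `type(weights) == bool` branch and handed to _deprecated_positional.
model = PM.resnet50(True)

# No weights requested: a randomly initialized model, as before.
model = PM.resnet50()
```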
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import resnet50, resnet101 from ..resnet import resnet50, resnet101
from ..resnet import ResNet50Weights, ResNet101Weights from ..resnet import ResNet50Weights, ResNet101Weights
...@@ -40,6 +40,7 @@ class DeepLabV3ResNet50Weights(Weights): ...@@ -40,6 +40,7 @@ class DeepLabV3ResNet50Weights(Weights):
"mIoU": 66.4, "mIoU": 66.4,
"acc": 92.4, "acc": 92.4,
}, },
default=True,
) )
...@@ -53,6 +54,7 @@ class DeepLabV3ResNet101Weights(Weights): ...@@ -53,6 +54,7 @@ class DeepLabV3ResNet101Weights(Weights):
"mIoU": 67.4, "mIoU": 67.4,
"acc": 92.4, "acc": 92.4,
}, },
default=True,
) )
...@@ -66,31 +68,37 @@ class DeepLabV3MobileNetV3LargeWeights(Weights): ...@@ -66,31 +68,37 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
"mIoU": 60.3, "mIoU": 60.3,
"acc": 91.2, "acc": 91.2,
}, },
default=True,
) )
def deeplabv3_resnet50( def deeplabv3_resnet50(
weights: Optional[DeepLabV3ResNet50Weights] = None, weights: Optional[DeepLabV3ResNet50Weights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1)
weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3ResNet50Weights.verify(weights) weights = DeepLabV3ResNet50Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
aux_loss = True num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_classes = len(weights.meta["categories"]) aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
...@@ -103,26 +111,31 @@ def deeplabv3_resnet50( ...@@ -103,26 +111,31 @@ def deeplabv3_resnet50(
def deeplabv3_resnet101( def deeplabv3_resnet101(
weights: Optional[DeepLabV3ResNet101Weights] = None, weights: Optional[DeepLabV3ResNet101Weights] = None,
weights_backbone: Optional[ResNet101Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1)
weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3ResNet101Weights.verify(weights) weights = DeepLabV3ResNet101Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet101Weights.verify(weights_backbone) weights_backbone = ResNet101Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
aux_loss = True num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_classes = len(weights.meta["categories"]) aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
...@@ -135,26 +148,33 @@ def deeplabv3_resnet101( ...@@ -135,26 +148,33 @@ def deeplabv3_resnet101(
def deeplabv3_mobilenet_v3_large( def deeplabv3_mobilenet_v3_large(
weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None, weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(
weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1
)
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
aux_loss = True num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_classes = len(weights.meta["categories"]) aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
......
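Editor's note: the `_ovewrite_value_param` helper imported above comes from `.._utils`, whose body is not part of this diff. Judging purely from its call sites (two arguments, returns the value the builder should use) and from the commit's stated goal of throwing errors on silent config overwrites, a minimal sketch could look like the following; the exact error message and typing are assumptions.

from typing import Optional, TypeVar

V = TypeVar("V")


def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
    # Keep an explicit user value only if it matches what the weights meta-data
    # dictates; a conflicting value raises instead of being silently overwritten
    # (the message wording here is an assumption).
    if param is not None and param != new_value:
        raise ValueError(f"The value {param} conflicts with the value {new_value} required by the weights meta-data.")
    return new_value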
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet from ....models.segmentation.fcn import FCN, _fcn_resnet
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
...@@ -30,6 +30,7 @@ class FCNResNet50Weights(Weights): ...@@ -30,6 +30,7 @@ class FCNResNet50Weights(Weights):
"mIoU": 60.5, "mIoU": 60.5,
"acc": 91.4, "acc": 91.4,
}, },
default=True,
) )
...@@ -43,30 +44,37 @@ class FCNResNet101Weights(Weights): ...@@ -43,30 +44,37 @@ class FCNResNet101Weights(Weights):
"mIoU": 63.7, "mIoU": 63.7,
"acc": 91.9, "acc": 91.9,
}, },
default=True,
) )
def fcn_resnet50( def fcn_resnet50(
weights: Optional[FCNResNet50Weights] = None, weights: Optional[FCNResNet50Weights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet50Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_RefV1)
weights = FCNResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = FCNResNet50Weights.verify(weights) weights = FCNResNet50Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
aux_loss = True
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
...@@ -79,25 +87,31 @@ def fcn_resnet50( ...@@ -79,25 +87,31 @@ def fcn_resnet50(
def fcn_resnet101( def fcn_resnet101(
weights: Optional[FCNResNet101Weights] = None, weights: Optional[FCNResNet101Weights] = None,
weights_backbone: Optional[ResNet101Weights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: Optional[int] = None,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
weights_backbone: Optional[ResNet101Weights] = None,
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_RefV1)
weights = FCNResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = FCNResNet101Weights.verify(weights) weights = FCNResNet101Weights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1
)
weights_backbone = ResNet101Weights.verify(weights_backbone) weights_backbone = ResNet101Weights.verify(weights_backbone)
if weights is not None: if weights is not None:
aux_loss = True
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
aux_loss = _ovewrite_value_param(aux_loss, True)
elif num_classes is None:
num_classes = 21
backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
......
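Editor's note: for readers tracking the API change, here is a short usage sketch of the rewritten FCN builders. The import path and the exact runtime behaviour of the legacy branch are assumptions based on this diff; the weight enum value is the one shown above.

from torchvision.prototype.models.segmentation import fcn_resnet50, FCNResNet50Weights  # path assumed

# New multi-weight API: pass a weights enum value, or None for random init.
model = fcn_resnet50(weights=FCNResNet50Weights.CocoWithVocLabels_RefV1)

# Legacy call: still accepted, but now routed through _deprecated_param, which
# warns and maps pretrained=True to the same default entry.
legacy = fcn_resnet50(pretrained=True)

# Custom head: num_classes can be chosen freely when no weights are requested;
# it only defaults to 21 (VOC) when left as None.
scratch = fcn_resnet50(weights=None, num_classes=10)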
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
...@@ -25,31 +25,40 @@ class LRASPPMobileNetV3LargeWeights(Weights): ...@@ -25,31 +25,40 @@ class LRASPPMobileNetV3LargeWeights(Weights):
"mIoU": 57.9, "mIoU": 57.9,
"acc": 91.2, "acc": 91.2,
}, },
default=True,
) )
def lraspp_mobilenet_v3_large( def lraspp_mobilenet_v3_large(
weights: Optional[LRASPPMobileNetV3LargeWeights] = None, weights: Optional[LRASPPMobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: Optional[int] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
**kwargs: Any, **kwargs: Any,
) -> LRASPP: ) -> LRASPP:
if kwargs.pop("aux_loss", False): if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss") raise NotImplementedError("This model does not use auxiliary loss")
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(
weights = LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1
)
weights = LRASPPMobileNetV3LargeWeights.verify(weights) weights = LRASPPMobileNetV3LargeWeights.verify(weights)
if type(weights_backbone) == bool and weights_backbone:
_deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.") weights_backbone = _deprecated_param(
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if weights is not None: if weights is not None:
weights_backbone = None weights_backbone = None
num_classes = len(weights.meta["categories"]) num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 21
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _lraspp_mobilenetv3(backbone, num_classes) model = _lraspp_mobilenetv3(backbone, num_classes)
......
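Editor's note: `_deprecated_param` replaces the repeated warn-and-pop blocks removed throughout these builders. Its body is outside this diff; a plausible sketch consistent with the call sites (a kwargs dict, the legacy and new parameter names, and the weight entry to substitute, with `None` meaning no checkpoint exists) is shown below. The warning and error texts are assumptions.

import warnings
from typing import Any, Dict, Optional


def _deprecated_param(
    kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[Any]
) -> Optional[Any]:
    warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
    if kwargs.pop(deprecated_param):
        if default_value is not None:
            return default_value
        # Builders without a released checkpoint pass default_value=None, so
        # asking for pretrained weights has to fail loudly (message assumed).
        raise ValueError(f"No checkpoint is available for parameter '{new_param}'.")
    return None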
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2 from ...models.shufflenetv2 import ShuffleNetV2
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -30,7 +30,7 @@ def _shufflenetv2( ...@@ -30,7 +30,7 @@ def _shufflenetv2(
**kwargs: Any, **kwargs: Any,
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ShuffleNetV2(*args, **kwargs) model = ShuffleNetV2(*args, **kwargs)
...@@ -57,6 +57,7 @@ class ShuffleNetV2_x0_5Weights(Weights): ...@@ -57,6 +57,7 @@ class ShuffleNetV2_x0_5Weights(Weights):
"acc@1": 69.362, "acc@1": 69.362,
"acc@5": 88.316, "acc@5": 88.316,
}, },
default=True,
) )
...@@ -69,6 +70,7 @@ class ShuffleNetV2_x1_0Weights(Weights): ...@@ -69,6 +70,7 @@ class ShuffleNetV2_x1_0Weights(Weights):
"acc@1": 60.552, "acc@1": 60.552,
"acc@5": 81.746, "acc@5": 81.746,
}, },
default=True,
) )
...@@ -83,9 +85,10 @@ class ShuffleNetV2_x2_0Weights(Weights): ...@@ -83,9 +85,10 @@ class ShuffleNetV2_x2_0Weights(Weights):
def shufflenet_v2_x0_5( def shufflenet_v2_x0_5(
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_Community)
weights = ShuffleNetV2_x0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x0_5Weights.verify(weights) weights = ShuffleNetV2_x0_5Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
...@@ -94,9 +97,10 @@ def shufflenet_v2_x0_5( ...@@ -94,9 +97,10 @@ def shufflenet_v2_x0_5(
def shufflenet_v2_x1_0( def shufflenet_v2_x1_0(
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_Community)
weights = ShuffleNetV2_x1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x1_0Weights.verify(weights) weights = ShuffleNetV2_x1_0Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
...@@ -105,10 +109,10 @@ def shufflenet_v2_x1_0( ...@@ -105,10 +109,10 @@ def shufflenet_v2_x1_0(
def shufflenet_v2_x1_5( def shufflenet_v2_x1_5(
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", None)
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x1_5")
weights = ShuffleNetV2_x1_5Weights.verify(weights) weights = ShuffleNetV2_x1_5Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
...@@ -117,10 +121,10 @@ def shufflenet_v2_x1_5( ...@@ -117,10 +121,10 @@ def shufflenet_v2_x1_5(
def shufflenet_v2_x2_0( def shufflenet_v2_x2_0(
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", None)
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x2_0")
weights = ShuffleNetV2_x2_0Weights.verify(weights) weights = ShuffleNetV2_x2_0Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
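Editor's note: the x1_5 and x2_0 ShuffleNet variants pass None as the default weight entry, so the old "no checkpoint" error path is preserved. Assuming the helper behaves as sketched earlier, the legacy calls round-trip roughly like this (import path assumed).

from torchvision.prototype.models import shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0  # path assumed

model = shufflenet_v2_x1_0(pretrained=True)   # warns, then loads ImageNet1K_Community

try:
    shufflenet_v2_x1_5(pretrained=True)       # warns, then raises: no checkpoint exists
except ValueError as exc:
    print(exc)

shufflenet_v2_x2_0(pretrained=False)          # warns only; no weights are loaded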
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet from ...models.squeezenet import SqueezeNet
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"] __all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
...@@ -30,6 +30,7 @@ class SqueezeNet1_0Weights(Weights): ...@@ -30,6 +30,7 @@ class SqueezeNet1_0Weights(Weights):
"acc@1": 58.092, "acc@1": 58.092,
"acc@5": 80.420, "acc@5": 80.420,
}, },
default=True,
) )
...@@ -42,16 +43,19 @@ class SqueezeNet1_1Weights(Weights): ...@@ -42,16 +43,19 @@ class SqueezeNet1_1Weights(Weights):
"acc@1": 58.178, "acc@1": 58.178,
"acc@5": 80.624, "acc@5": 80.624,
}, },
default=True,
) )
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_Community)
weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_0Weights.verify(weights) weights = SqueezeNet1_0Weights.verify(weights)
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SqueezeNet("1_0", **kwargs) model = SqueezeNet("1_0", **kwargs)
...@@ -62,12 +66,14 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool ...@@ -62,12 +66,14 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_Community)
weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_1Weights.verify(weights) weights = SqueezeNet1_1Weights.verify(weights)
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = SqueezeNet("1_1", **kwargs) model = SqueezeNet("1_1", **kwargs)
......
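Editor's note: `_ovewrite_named_param` plays the same role as `_ovewrite_value_param` but operates on the kwargs dict used by builders such as the SqueezeNet ones above. A minimal sketch consistent with the call sites follows; the error message is an assumption.

from typing import Any, Dict, TypeVar

V = TypeVar("V")


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    # A user-supplied kwarg may only survive if it already matches the value
    # implied by the weights meta-data; a mismatch raises instead of being
    # silently overwritten.
    if param in kwargs and kwargs[param] != new_value:
        raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    kwargs[param] = new_value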
import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -8,6 +7,7 @@ from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs from ...models.vgg import VGG, make_layers, cfgs
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -33,7 +33,7 @@ __all__ = [ ...@@ -33,7 +33,7 @@ __all__ = [
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
...@@ -57,6 +57,7 @@ class VGG11Weights(Weights): ...@@ -57,6 +57,7 @@ class VGG11Weights(Weights):
"acc@1": 69.020, "acc@1": 69.020,
"acc@5": 88.628, "acc@5": 88.628,
}, },
default=True,
) )
...@@ -69,6 +70,7 @@ class VGG11BNWeights(Weights): ...@@ -69,6 +70,7 @@ class VGG11BNWeights(Weights):
"acc@1": 70.370, "acc@1": 70.370,
"acc@5": 89.810, "acc@5": 89.810,
}, },
default=True,
) )
...@@ -81,6 +83,7 @@ class VGG13Weights(Weights): ...@@ -81,6 +83,7 @@ class VGG13Weights(Weights):
"acc@1": 69.928, "acc@1": 69.928,
"acc@5": 89.246, "acc@5": 89.246,
}, },
default=True,
) )
...@@ -93,6 +96,7 @@ class VGG13BNWeights(Weights): ...@@ -93,6 +96,7 @@ class VGG13BNWeights(Weights):
"acc@1": 71.586, "acc@1": 71.586,
"acc@5": 90.374, "acc@5": 90.374,
}, },
default=True,
) )
...@@ -105,6 +109,7 @@ class VGG16Weights(Weights): ...@@ -105,6 +109,7 @@ class VGG16Weights(Weights):
"acc@1": 71.592, "acc@1": 71.592,
"acc@5": 90.382, "acc@5": 90.382,
}, },
default=True,
) )
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the # same input standardization method as the paper. Only the `features` weights have proper values, those on the
...@@ -122,6 +127,7 @@ class VGG16Weights(Weights): ...@@ -122,6 +127,7 @@ class VGG16Weights(Weights):
"acc@1": float("nan"), "acc@1": float("nan"),
"acc@5": float("nan"), "acc@5": float("nan"),
}, },
default=False,
) )
...@@ -134,6 +140,7 @@ class VGG16BNWeights(Weights): ...@@ -134,6 +140,7 @@ class VGG16BNWeights(Weights):
"acc@1": 73.360, "acc@1": 73.360,
"acc@5": 91.516, "acc@5": 91.516,
}, },
default=True,
) )
...@@ -146,6 +153,7 @@ class VGG19Weights(Weights): ...@@ -146,6 +153,7 @@ class VGG19Weights(Weights):
"acc@1": 72.376, "acc@1": 72.376,
"acc@5": 90.876, "acc@5": 90.876,
}, },
default=True,
) )
...@@ -158,76 +166,85 @@ class VGG19BNWeights(Weights): ...@@ -158,76 +166,85 @@ class VGG19BNWeights(Weights):
"acc@1": 74.218, "acc@1": 74.218,
"acc@5": 91.842, "acc@5": 91.842,
}, },
default=True,
) )
def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_RefV1)
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11Weights.verify(weights) weights = VGG11Weights.verify(weights)
return _vgg("A", False, weights, progress, **kwargs) return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_RefV1)
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11BNWeights.verify(weights) weights = VGG11BNWeights.verify(weights)
return _vgg("A", True, weights, progress, **kwargs) return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_RefV1)
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13Weights.verify(weights) weights = VGG13Weights.verify(weights)
return _vgg("B", False, weights, progress, **kwargs) return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_RefV1)
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13BNWeights.verify(weights) weights = VGG13BNWeights.verify(weights)
return _vgg("B", True, weights, progress, **kwargs) return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_RefV1)
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16Weights.verify(weights) weights = VGG16Weights.verify(weights)
return _vgg("D", False, weights, progress, **kwargs) return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_RefV1)
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16BNWeights.verify(weights) weights = VGG16BNWeights.verify(weights)
return _vgg("D", True, weights, progress, **kwargs) return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_RefV1)
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19Weights.verify(weights) weights = VGG19Weights.verify(weights)
return _vgg("E", False, weights, progress, **kwargs) return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_RefV1)
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19BNWeights.verify(weights) weights = VGG19BNWeights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs) return _vgg("E", True, weights, progress, **kwargs)
import warnings
from functools import partial from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union from typing import Any, Callable, List, Optional, Sequence, Type, Union
...@@ -18,6 +17,7 @@ from ....models.video.resnet import ( ...@@ -18,6 +17,7 @@ from ....models.video.resnet import (
) )
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _KINETICS400_CATEGORIES from .._meta import _KINETICS400_CATEGORIES
from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
__all__ = [ __all__ = [
...@@ -41,7 +41,7 @@ def _video_resnet( ...@@ -41,7 +41,7 @@ def _video_resnet(
**kwargs: Any, **kwargs: Any,
) -> VideoResNet: ) -> VideoResNet:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VideoResNet(block, conv_makers, layers, stem, **kwargs) model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
...@@ -68,6 +68,7 @@ class R3D_18Weights(Weights): ...@@ -68,6 +68,7 @@ class R3D_18Weights(Weights):
"acc@1": 52.75, "acc@1": 52.75,
"acc@5": 75.45, "acc@5": 75.45,
}, },
default=True,
) )
...@@ -80,6 +81,7 @@ class MC3_18Weights(Weights): ...@@ -80,6 +81,7 @@ class MC3_18Weights(Weights):
"acc@1": 53.90, "acc@1": 53.90,
"acc@5": 76.29, "acc@5": 76.29,
}, },
default=True,
) )
...@@ -92,13 +94,15 @@ class R2Plus1D_18Weights(Weights): ...@@ -92,13 +94,15 @@ class R2Plus1D_18Weights(Weights):
"acc@1": 57.50, "acc@1": 57.50,
"acc@5": 78.81, "acc@5": 78.81,
}, },
default=True,
) )
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_RefV1)
weights = R3D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = R3D_18Weights.verify(weights) weights = R3D_18Weights.verify(weights)
return _video_resnet( return _video_resnet(
...@@ -113,9 +117,10 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa ...@@ -113,9 +117,10 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_RefV1)
weights = MC3_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = MC3_18Weights.verify(weights) weights = MC3_18Weights.verify(weights)
return _video_resnet( return _video_resnet(
...@@ -130,9 +135,10 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa ...@@ -130,9 +135,10 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.") weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_RefV1)
weights = R2Plus1D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = R2Plus1D_18Weights.verify(weights) weights = R2Plus1D_18Weights.verify(weights)
return _video_resnet( return _video_resnet(
......
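Editor's note: `_deprecated_positional` covers the remaining legacy pattern of passing a positional boolean (e.g. r3d_18(True)). Its body is not in this diff; a sketch consistent with the calls above, which re-injects the value as the legacy keyword so `_deprecated_param` can then handle it, might be the following. The warning text and the re-injection behaviour are assumptions.

import warnings
from typing import Any, Dict


def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Any) -> None:
    warnings.warn(
        f"The positional parameter '{deprecated_param}' is deprecated, "
        f"please use the keyword parameter '{new_param}' instead."
    )
    # Reinstate the positionally passed value as the legacy keyword so the
    # subsequent `if "pretrained" in kwargs` branch picks it up (assumed behaviour).
    kwargs[deprecated_param] = default_value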