Unverified Commit 3d8723d5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor the `get_weights` API (#5006)

* Change the `default` weights mechanism to sue Enum aliases.

* Change `get_weights` to work with full Enum names and make it public.

* Applying improvements from code review.
parent 65cdaeab
......@@ -25,8 +25,8 @@ class MobileNet_V2_Weights(WeightsEnum):
"acc@1": 71.878,
"acc@5": 90.286,
},
default=True,
)
default = ImageNet1K_V1
def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
......
......@@ -54,7 +54,6 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
"acc@1": 74.042,
"acc@5": 91.340,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
......@@ -65,8 +64,8 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
"acc@1": 75.274,
"acc@5": 92.566,
},
default=True,
)
default = ImageNet1K_V2
class MobileNet_V3_Small_Weights(WeightsEnum):
......@@ -79,8 +78,8 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
"acc@1": 67.668,
"acc@5": 87.402,
},
default=True,
)
default = ImageNet1K_V1
def mobilenet_v3_large(
......
......@@ -38,8 +38,8 @@ class GoogLeNet_QuantizedWeights(WeightsEnum):
"acc@1": 69.826,
"acc@5": 89.404,
},
default=True,
)
default = ImageNet1K_FBGEMM_V1
def googlenet(
......
......@@ -37,8 +37,8 @@ class Inception_V3_QuantizedWeights(WeightsEnum):
"acc@1": 77.176,
"acc@5": 93.354,
},
default=True,
)
default = ImageNet1K_FBGEMM_V1
def inception_v3(
......
......@@ -38,8 +38,8 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
"acc@1": 71.658,
"acc@5": 90.150,
},
default=True,
)
default = ImageNet1K_QNNPACK_V1
def mobilenet_v2(
......
......@@ -71,8 +71,8 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
"acc@1": 73.004,
"acc@5": 90.858,
},
default=True,
)
default = ImageNet1K_QNNPACK_V1
def mobilenet_v3_large(
......
......@@ -73,8 +73,8 @@ class ResNet18_QuantizedWeights(WeightsEnum):
"acc@1": 69.494,
"acc@5": 88.882,
},
default=True,
)
default = ImageNet1K_FBGEMM_V1
class ResNet50_QuantizedWeights(WeightsEnum):
......@@ -87,7 +87,6 @@ class ResNet50_QuantizedWeights(WeightsEnum):
"acc@1": 75.920,
"acc@5": 92.814,
},
default=False,
)
ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
......@@ -98,8 +97,8 @@ class ResNet50_QuantizedWeights(WeightsEnum):
"acc@1": 80.282,
"acc@5": 94.976,
},
default=True,
)
default = ImageNet1K_FBGEMM_V2
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
......@@ -112,7 +111,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
"acc@1": 78.986,
"acc@5": 94.480,
},
default=False,
)
ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
......@@ -123,8 +121,8 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
"acc@1": 82.574,
"acc@5": 96.132,
},
default=True,
)
default = ImageNet1K_FBGEMM_V2
def resnet18(
......
......@@ -69,8 +69,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
"acc@1": 57.972,
"acc@5": 79.780,
},
default=True,
)
default = ImageNet1K_FBGEMM_V1
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
......@@ -83,8 +83,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
"acc@1": 68.360,
"acc@5": 87.582,
},
default=True,
)
default = ImageNet1K_FBGEMM_V1
def shufflenet_v2_x0_5(
......
......@@ -74,8 +74,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
"acc@1": 74.046,
"acc@5": 91.716,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_Y_800MF_Weights(WeightsEnum):
......@@ -88,8 +88,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
"acc@1": 76.420,
"acc@5": 93.136,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_Y_1_6GF_Weights(WeightsEnum):
......@@ -102,8 +102,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
"acc@1": 77.950,
"acc@5": 93.966,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_Y_3_2GF_Weights(WeightsEnum):
......@@ -116,8 +116,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
"acc@1": 78.948,
"acc@5": 94.576,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_Y_8GF_Weights(WeightsEnum):
......@@ -130,8 +130,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
"acc@1": 80.032,
"acc@5": 95.048,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_Y_16GF_Weights(WeightsEnum):
......@@ -144,8 +144,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@1": 80.424,
"acc@5": 95.240,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_Y_32GF_Weights(WeightsEnum):
......@@ -158,8 +158,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@1": 80.878,
"acc@5": 95.340,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_400MF_Weights(WeightsEnum):
......@@ -172,8 +172,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
"acc@1": 72.834,
"acc@5": 90.950,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_800MF_Weights(WeightsEnum):
......@@ -186,8 +186,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
"acc@1": 75.212,
"acc@5": 92.348,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_1_6GF_Weights(WeightsEnum):
......@@ -200,8 +200,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
"acc@1": 77.040,
"acc@5": 93.440,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_3_2GF_Weights(WeightsEnum):
......@@ -214,8 +214,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
"acc@1": 78.364,
"acc@5": 93.992,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_8GF_Weights(WeightsEnum):
......@@ -228,8 +228,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
"acc@1": 79.344,
"acc@5": 94.686,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_16GF_Weights(WeightsEnum):
......@@ -242,8 +242,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
"acc@1": 80.058,
"acc@5": 94.944,
},
default=True,
)
default = ImageNet1K_V1
class RegNet_X_32GF_Weights(WeightsEnum):
......@@ -256,8 +256,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
"acc@1": 80.622,
"acc@5": 95.248,
},
default=True,
)
default = ImageNet1K_V1
def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
......
......@@ -64,8 +64,8 @@ class ResNet18_Weights(WeightsEnum):
"acc@1": 69.758,
"acc@5": 89.078,
},
default=True,
)
default = ImageNet1K_V1
class ResNet34_Weights(WeightsEnum):
......@@ -78,8 +78,8 @@ class ResNet34_Weights(WeightsEnum):
"acc@1": 73.314,
"acc@5": 91.420,
},
default=True,
)
default = ImageNet1K_V1
class ResNet50_Weights(WeightsEnum):
......@@ -92,7 +92,6 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 76.130,
"acc@5": 92.862,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
......@@ -103,8 +102,8 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 80.674,
"acc@5": 95.166,
},
default=True,
)
default = ImageNet1K_V2
class ResNet101_Weights(WeightsEnum):
......@@ -117,7 +116,6 @@ class ResNet101_Weights(WeightsEnum):
"acc@1": 77.374,
"acc@5": 93.546,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
......@@ -128,8 +126,8 @@ class ResNet101_Weights(WeightsEnum):
"acc@1": 81.886,
"acc@5": 95.780,
},
default=True,
)
default = ImageNet1K_V2
class ResNet152_Weights(WeightsEnum):
......@@ -142,7 +140,6 @@ class ResNet152_Weights(WeightsEnum):
"acc@1": 78.312,
"acc@5": 94.046,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
......@@ -153,8 +150,8 @@ class ResNet152_Weights(WeightsEnum):
"acc@1": 82.284,
"acc@5": 96.002,
},
default=True,
)
default = ImageNet1K_V2
class ResNeXt50_32X4D_Weights(WeightsEnum):
......@@ -167,7 +164,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@1": 77.618,
"acc@5": 93.698,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
......@@ -178,8 +174,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@1": 81.198,
"acc@5": 95.340,
},
default=True,
)
default = ImageNet1K_V2
class ResNeXt101_32X8D_Weights(WeightsEnum):
......@@ -192,7 +188,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@1": 79.312,
"acc@5": 94.526,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
......@@ -203,8 +198,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@1": 82.834,
"acc@5": 96.228,
},
default=True,
)
default = ImageNet1K_V2
class Wide_ResNet50_2_Weights(WeightsEnum):
......@@ -217,7 +212,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@1": 78.468,
"acc@5": 94.086,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
......@@ -228,8 +222,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@1": 81.602,
"acc@5": 95.758,
},
default=True,
)
default = ImageNet1K_V2
class Wide_ResNet101_2_Weights(WeightsEnum):
......@@ -242,7 +236,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@1": 78.848,
"acc@5": 94.284,
},
default=False,
)
ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
......@@ -253,8 +246,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@1": 82.510,
"acc@5": 96.020,
},
default=True,
)
default = ImageNet1K_V2
def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
......
......@@ -40,8 +40,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
"mIoU": 66.4,
"acc": 92.4,
},
default=True,
)
default = CocoWithVocLabels_V1
class DeepLabV3_ResNet101_Weights(WeightsEnum):
......@@ -54,8 +54,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
"mIoU": 67.4,
"acc": 92.4,
},
default=True,
)
default = CocoWithVocLabels_V1
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
......@@ -68,8 +68,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
"mIoU": 60.3,
"acc": 91.2,
},
default=True,
)
default = CocoWithVocLabels_V1
def deeplabv3_resnet50(
......
......@@ -30,8 +30,8 @@ class FCN_ResNet50_Weights(WeightsEnum):
"mIoU": 60.5,
"acc": 91.4,
},
default=True,
)
default = CocoWithVocLabels_V1
class FCN_ResNet101_Weights(WeightsEnum):
......@@ -44,8 +44,8 @@ class FCN_ResNet101_Weights(WeightsEnum):
"mIoU": 63.7,
"acc": 91.9,
},
default=True,
)
default = CocoWithVocLabels_V1
def fcn_resnet50(
......
......@@ -25,8 +25,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
"mIoU": 57.9,
"acc": 91.2,
},
default=True,
)
default = CocoWithVocLabels_V1
def lraspp_mobilenet_v3_large(
......
......@@ -57,8 +57,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
"acc@1": 69.362,
"acc@5": 88.316,
},
default=True,
)
default = ImageNet1K_V1
class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
......@@ -70,8 +70,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
"acc@1": 60.552,
"acc@5": 81.746,
},
default=True,
)
default = ImageNet1K_V1
class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
......
......@@ -30,8 +30,8 @@ class SqueezeNet1_0_Weights(WeightsEnum):
"acc@1": 58.092,
"acc@5": 80.420,
},
default=True,
)
default = ImageNet1K_V1
class SqueezeNet1_1_Weights(WeightsEnum):
......@@ -43,8 +43,8 @@ class SqueezeNet1_1_Weights(WeightsEnum):
"acc@1": 58.178,
"acc@5": 80.624,
},
default=True,
)
default = ImageNet1K_V1
def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
......
......@@ -57,8 +57,8 @@ class VGG11_Weights(WeightsEnum):
"acc@1": 69.020,
"acc@5": 88.628,
},
default=True,
)
default = ImageNet1K_V1
class VGG11_BN_Weights(WeightsEnum):
......@@ -70,8 +70,8 @@ class VGG11_BN_Weights(WeightsEnum):
"acc@1": 70.370,
"acc@5": 89.810,
},
default=True,
)
default = ImageNet1K_V1
class VGG13_Weights(WeightsEnum):
......@@ -83,8 +83,8 @@ class VGG13_Weights(WeightsEnum):
"acc@1": 69.928,
"acc@5": 89.246,
},
default=True,
)
default = ImageNet1K_V1
class VGG13_BN_Weights(WeightsEnum):
......@@ -96,8 +96,8 @@ class VGG13_BN_Weights(WeightsEnum):
"acc@1": 71.586,
"acc@5": 90.374,
},
default=True,
)
default = ImageNet1K_V1
class VGG16_Weights(WeightsEnum):
......@@ -109,7 +109,6 @@ class VGG16_Weights(WeightsEnum):
"acc@1": 71.592,
"acc@5": 90.382,
},
default=True,
)
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the
......@@ -127,8 +126,8 @@ class VGG16_Weights(WeightsEnum):
"acc@1": float("nan"),
"acc@5": float("nan"),
},
default=False,
)
default = ImageNet1K_V1
class VGG16_BN_Weights(WeightsEnum):
......@@ -140,8 +139,8 @@ class VGG16_BN_Weights(WeightsEnum):
"acc@1": 73.360,
"acc@5": 91.516,
},
default=True,
)
default = ImageNet1K_V1
class VGG19_Weights(WeightsEnum):
......@@ -153,8 +152,8 @@ class VGG19_Weights(WeightsEnum):
"acc@1": 72.376,
"acc@5": 90.876,
},
default=True,
)
default = ImageNet1K_V1
class VGG19_BN_Weights(WeightsEnum):
......@@ -166,8 +165,8 @@ class VGG19_BN_Weights(WeightsEnum):
"acc@1": 74.218,
"acc@5": 91.842,
},
default=True,
)
default = ImageNet1K_V1
def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......
......@@ -68,8 +68,8 @@ class R3D_18_Weights(WeightsEnum):
"acc@1": 52.75,
"acc@5": 75.45,
},
default=True,
)
default = Kinetics400_V1
class MC3_18_Weights(WeightsEnum):
......@@ -81,8 +81,8 @@ class MC3_18_Weights(WeightsEnum):
"acc@1": 53.90,
"acc@5": 76.29,
},
default=True,
)
default = Kinetics400_V1
class R2Plus1D_18_Weights(WeightsEnum):
......@@ -94,8 +94,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
"acc@1": 57.50,
"acc@5": 78.81,
},
default=True,
)
default = Kinetics400_V1
def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment