Unverified Commit 3d8723d5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor the `get_weights` API (#5006)

* Change the `default` weights mechanism to sue Enum aliases.

* Change `get_weights` to work with full Enum names and make it public.

* Applying improvements from code review.
parent 65cdaeab
...@@ -25,8 +25,8 @@ class MobileNet_V2_Weights(WeightsEnum): ...@@ -25,8 +25,8 @@ class MobileNet_V2_Weights(WeightsEnum):
"acc@1": 71.878, "acc@1": 71.878,
"acc@5": 90.286, "acc@5": 90.286,
}, },
default=True,
) )
default = ImageNet1K_V1
def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
......
...@@ -54,7 +54,6 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -54,7 +54,6 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
"acc@1": 74.042, "acc@1": 74.042,
"acc@5": 91.340, "acc@5": 91.340,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
...@@ -65,8 +64,8 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -65,8 +64,8 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
"acc@1": 75.274, "acc@1": 75.274,
"acc@5": 92.566, "acc@5": 92.566,
}, },
default=True,
) )
default = ImageNet1K_V2
class MobileNet_V3_Small_Weights(WeightsEnum): class MobileNet_V3_Small_Weights(WeightsEnum):
...@@ -79,8 +78,8 @@ class MobileNet_V3_Small_Weights(WeightsEnum): ...@@ -79,8 +78,8 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
"acc@1": 67.668, "acc@1": 67.668,
"acc@5": 87.402, "acc@5": 87.402,
}, },
default=True,
) )
default = ImageNet1K_V1
def mobilenet_v3_large( def mobilenet_v3_large(
......
...@@ -38,8 +38,8 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): ...@@ -38,8 +38,8 @@ class GoogLeNet_QuantizedWeights(WeightsEnum):
"acc@1": 69.826, "acc@1": 69.826,
"acc@5": 89.404, "acc@5": 89.404,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V1
def googlenet( def googlenet(
......
...@@ -37,8 +37,8 @@ class Inception_V3_QuantizedWeights(WeightsEnum): ...@@ -37,8 +37,8 @@ class Inception_V3_QuantizedWeights(WeightsEnum):
"acc@1": 77.176, "acc@1": 77.176,
"acc@5": 93.354, "acc@5": 93.354,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V1
def inception_v3( def inception_v3(
......
...@@ -38,8 +38,8 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): ...@@ -38,8 +38,8 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
"acc@1": 71.658, "acc@1": 71.658,
"acc@5": 90.150, "acc@5": 90.150,
}, },
default=True,
) )
default = ImageNet1K_QNNPACK_V1
def mobilenet_v2( def mobilenet_v2(
......
...@@ -71,8 +71,8 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): ...@@ -71,8 +71,8 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
"acc@1": 73.004, "acc@1": 73.004,
"acc@5": 90.858, "acc@5": 90.858,
}, },
default=True,
) )
default = ImageNet1K_QNNPACK_V1
def mobilenet_v3_large( def mobilenet_v3_large(
......
...@@ -73,8 +73,8 @@ class ResNet18_QuantizedWeights(WeightsEnum): ...@@ -73,8 +73,8 @@ class ResNet18_QuantizedWeights(WeightsEnum):
"acc@1": 69.494, "acc@1": 69.494,
"acc@5": 88.882, "acc@5": 88.882,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V1
class ResNet50_QuantizedWeights(WeightsEnum): class ResNet50_QuantizedWeights(WeightsEnum):
...@@ -87,7 +87,6 @@ class ResNet50_QuantizedWeights(WeightsEnum): ...@@ -87,7 +87,6 @@ class ResNet50_QuantizedWeights(WeightsEnum):
"acc@1": 75.920, "acc@1": 75.920,
"acc@5": 92.814, "acc@5": 92.814,
}, },
default=False,
) )
ImageNet1K_FBGEMM_V2 = Weights( ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
...@@ -98,8 +97,8 @@ class ResNet50_QuantizedWeights(WeightsEnum): ...@@ -98,8 +97,8 @@ class ResNet50_QuantizedWeights(WeightsEnum):
"acc@1": 80.282, "acc@1": 80.282,
"acc@5": 94.976, "acc@5": 94.976,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V2
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
...@@ -112,7 +111,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ...@@ -112,7 +111,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
"acc@1": 78.986, "acc@1": 78.986,
"acc@5": 94.480, "acc@5": 94.480,
}, },
default=False,
) )
ImageNet1K_FBGEMM_V2 = Weights( ImageNet1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
...@@ -123,8 +121,8 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ...@@ -123,8 +121,8 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
"acc@1": 82.574, "acc@1": 82.574,
"acc@5": 96.132, "acc@5": 96.132,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V2
def resnet18( def resnet18(
......
...@@ -69,8 +69,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): ...@@ -69,8 +69,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
"acc@1": 57.972, "acc@1": 57.972,
"acc@5": 79.780, "acc@5": 79.780,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V1
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
...@@ -83,8 +83,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): ...@@ -83,8 +83,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
"acc@1": 68.360, "acc@1": 68.360,
"acc@5": 87.582, "acc@5": 87.582,
}, },
default=True,
) )
default = ImageNet1K_FBGEMM_V1
def shufflenet_v2_x0_5( def shufflenet_v2_x0_5(
......
...@@ -74,8 +74,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum): ...@@ -74,8 +74,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
"acc@1": 74.046, "acc@1": 74.046,
"acc@5": 91.716, "acc@5": 91.716,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_Y_800MF_Weights(WeightsEnum): class RegNet_Y_800MF_Weights(WeightsEnum):
...@@ -88,8 +88,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ...@@ -88,8 +88,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
"acc@1": 76.420, "acc@1": 76.420,
"acc@5": 93.136, "acc@5": 93.136,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_Y_1_6GF_Weights(WeightsEnum): class RegNet_Y_1_6GF_Weights(WeightsEnum):
...@@ -102,8 +102,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ...@@ -102,8 +102,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
"acc@1": 77.950, "acc@1": 77.950,
"acc@5": 93.966, "acc@5": 93.966,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_Y_3_2GF_Weights(WeightsEnum): class RegNet_Y_3_2GF_Weights(WeightsEnum):
...@@ -116,8 +116,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ...@@ -116,8 +116,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
"acc@1": 78.948, "acc@1": 78.948,
"acc@5": 94.576, "acc@5": 94.576,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_Y_8GF_Weights(WeightsEnum): class RegNet_Y_8GF_Weights(WeightsEnum):
...@@ -130,8 +130,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ...@@ -130,8 +130,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
"acc@1": 80.032, "acc@1": 80.032,
"acc@5": 95.048, "acc@5": 95.048,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_Y_16GF_Weights(WeightsEnum): class RegNet_Y_16GF_Weights(WeightsEnum):
...@@ -144,8 +144,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -144,8 +144,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@1": 80.424, "acc@1": 80.424,
"acc@5": 95.240, "acc@5": 95.240,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_Y_32GF_Weights(WeightsEnum): class RegNet_Y_32GF_Weights(WeightsEnum):
...@@ -158,8 +158,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -158,8 +158,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@1": 80.878, "acc@1": 80.878,
"acc@5": 95.340, "acc@5": 95.340,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum):
...@@ -172,8 +172,8 @@ class RegNet_X_400MF_Weights(WeightsEnum): ...@@ -172,8 +172,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
"acc@1": 72.834, "acc@1": 72.834,
"acc@5": 90.950, "acc@5": 90.950,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_800MF_Weights(WeightsEnum): class RegNet_X_800MF_Weights(WeightsEnum):
...@@ -186,8 +186,8 @@ class RegNet_X_800MF_Weights(WeightsEnum): ...@@ -186,8 +186,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
"acc@1": 75.212, "acc@1": 75.212,
"acc@5": 92.348, "acc@5": 92.348,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_1_6GF_Weights(WeightsEnum): class RegNet_X_1_6GF_Weights(WeightsEnum):
...@@ -200,8 +200,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ...@@ -200,8 +200,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
"acc@1": 77.040, "acc@1": 77.040,
"acc@5": 93.440, "acc@5": 93.440,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_3_2GF_Weights(WeightsEnum): class RegNet_X_3_2GF_Weights(WeightsEnum):
...@@ -214,8 +214,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ...@@ -214,8 +214,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
"acc@1": 78.364, "acc@1": 78.364,
"acc@5": 93.992, "acc@5": 93.992,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_8GF_Weights(WeightsEnum): class RegNet_X_8GF_Weights(WeightsEnum):
...@@ -228,8 +228,8 @@ class RegNet_X_8GF_Weights(WeightsEnum): ...@@ -228,8 +228,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
"acc@1": 79.344, "acc@1": 79.344,
"acc@5": 94.686, "acc@5": 94.686,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_16GF_Weights(WeightsEnum): class RegNet_X_16GF_Weights(WeightsEnum):
...@@ -242,8 +242,8 @@ class RegNet_X_16GF_Weights(WeightsEnum): ...@@ -242,8 +242,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
"acc@1": 80.058, "acc@1": 80.058,
"acc@5": 94.944, "acc@5": 94.944,
}, },
default=True,
) )
default = ImageNet1K_V1
class RegNet_X_32GF_Weights(WeightsEnum): class RegNet_X_32GF_Weights(WeightsEnum):
...@@ -256,8 +256,8 @@ class RegNet_X_32GF_Weights(WeightsEnum): ...@@ -256,8 +256,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
"acc@1": 80.622, "acc@1": 80.622,
"acc@5": 95.248, "acc@5": 95.248,
}, },
default=True,
) )
default = ImageNet1K_V1
def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
......
...@@ -64,8 +64,8 @@ class ResNet18_Weights(WeightsEnum): ...@@ -64,8 +64,8 @@ class ResNet18_Weights(WeightsEnum):
"acc@1": 69.758, "acc@1": 69.758,
"acc@5": 89.078, "acc@5": 89.078,
}, },
default=True,
) )
default = ImageNet1K_V1
class ResNet34_Weights(WeightsEnum): class ResNet34_Weights(WeightsEnum):
...@@ -78,8 +78,8 @@ class ResNet34_Weights(WeightsEnum): ...@@ -78,8 +78,8 @@ class ResNet34_Weights(WeightsEnum):
"acc@1": 73.314, "acc@1": 73.314,
"acc@5": 91.420, "acc@5": 91.420,
}, },
default=True,
) )
default = ImageNet1K_V1
class ResNet50_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum):
...@@ -92,7 +92,6 @@ class ResNet50_Weights(WeightsEnum): ...@@ -92,7 +92,6 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 76.130, "acc@1": 76.130,
"acc@5": 92.862, "acc@5": 92.862,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
...@@ -103,8 +102,8 @@ class ResNet50_Weights(WeightsEnum): ...@@ -103,8 +102,8 @@ class ResNet50_Weights(WeightsEnum):
"acc@1": 80.674, "acc@1": 80.674,
"acc@5": 95.166, "acc@5": 95.166,
}, },
default=True,
) )
default = ImageNet1K_V2
class ResNet101_Weights(WeightsEnum): class ResNet101_Weights(WeightsEnum):
...@@ -117,7 +116,6 @@ class ResNet101_Weights(WeightsEnum): ...@@ -117,7 +116,6 @@ class ResNet101_Weights(WeightsEnum):
"acc@1": 77.374, "acc@1": 77.374,
"acc@5": 93.546, "acc@5": 93.546,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
...@@ -128,8 +126,8 @@ class ResNet101_Weights(WeightsEnum): ...@@ -128,8 +126,8 @@ class ResNet101_Weights(WeightsEnum):
"acc@1": 81.886, "acc@1": 81.886,
"acc@5": 95.780, "acc@5": 95.780,
}, },
default=True,
) )
default = ImageNet1K_V2
class ResNet152_Weights(WeightsEnum): class ResNet152_Weights(WeightsEnum):
...@@ -142,7 +140,6 @@ class ResNet152_Weights(WeightsEnum): ...@@ -142,7 +140,6 @@ class ResNet152_Weights(WeightsEnum):
"acc@1": 78.312, "acc@1": 78.312,
"acc@5": 94.046, "acc@5": 94.046,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth", url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
...@@ -153,8 +150,8 @@ class ResNet152_Weights(WeightsEnum): ...@@ -153,8 +150,8 @@ class ResNet152_Weights(WeightsEnum):
"acc@1": 82.284, "acc@1": 82.284,
"acc@5": 96.002, "acc@5": 96.002,
}, },
default=True,
) )
default = ImageNet1K_V2
class ResNeXt50_32X4D_Weights(WeightsEnum): class ResNeXt50_32X4D_Weights(WeightsEnum):
...@@ -167,7 +164,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ...@@ -167,7 +164,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@1": 77.618, "acc@1": 77.618,
"acc@5": 93.698, "acc@5": 93.698,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
...@@ -178,8 +174,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ...@@ -178,8 +174,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@1": 81.198, "acc@1": 81.198,
"acc@5": 95.340, "acc@5": 95.340,
}, },
default=True,
) )
default = ImageNet1K_V2
class ResNeXt101_32X8D_Weights(WeightsEnum): class ResNeXt101_32X8D_Weights(WeightsEnum):
...@@ -192,7 +188,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ...@@ -192,7 +188,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@1": 79.312, "acc@1": 79.312,
"acc@5": 94.526, "acc@5": 94.526,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
...@@ -203,8 +198,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ...@@ -203,8 +198,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@1": 82.834, "acc@1": 82.834,
"acc@5": 96.228, "acc@5": 96.228,
}, },
default=True,
) )
default = ImageNet1K_V2
class Wide_ResNet50_2_Weights(WeightsEnum): class Wide_ResNet50_2_Weights(WeightsEnum):
...@@ -217,7 +212,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ...@@ -217,7 +212,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@1": 78.468, "acc@1": 78.468,
"acc@5": 94.086, "acc@5": 94.086,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
...@@ -228,8 +222,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ...@@ -228,8 +222,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@1": 81.602, "acc@1": 81.602,
"acc@5": 95.758, "acc@5": 95.758,
}, },
default=True,
) )
default = ImageNet1K_V2
class Wide_ResNet101_2_Weights(WeightsEnum): class Wide_ResNet101_2_Weights(WeightsEnum):
...@@ -242,7 +236,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ...@@ -242,7 +236,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@1": 78.848, "acc@1": 78.848,
"acc@5": 94.284, "acc@5": 94.284,
}, },
default=False,
) )
ImageNet1K_V2 = Weights( ImageNet1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
...@@ -253,8 +246,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ...@@ -253,8 +246,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@1": 82.510, "acc@1": 82.510,
"acc@5": 96.020, "acc@5": 96.020,
}, },
default=True,
) )
default = ImageNet1K_V2
def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
......
...@@ -40,8 +40,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): ...@@ -40,8 +40,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
"mIoU": 66.4, "mIoU": 66.4,
"acc": 92.4, "acc": 92.4,
}, },
default=True,
) )
default = CocoWithVocLabels_V1
class DeepLabV3_ResNet101_Weights(WeightsEnum): class DeepLabV3_ResNet101_Weights(WeightsEnum):
...@@ -54,8 +54,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): ...@@ -54,8 +54,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
"mIoU": 67.4, "mIoU": 67.4,
"acc": 92.4, "acc": 92.4,
}, },
default=True,
) )
default = CocoWithVocLabels_V1
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
...@@ -68,8 +68,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -68,8 +68,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
"mIoU": 60.3, "mIoU": 60.3,
"acc": 91.2, "acc": 91.2,
}, },
default=True,
) )
default = CocoWithVocLabels_V1
def deeplabv3_resnet50( def deeplabv3_resnet50(
......
...@@ -30,8 +30,8 @@ class FCN_ResNet50_Weights(WeightsEnum): ...@@ -30,8 +30,8 @@ class FCN_ResNet50_Weights(WeightsEnum):
"mIoU": 60.5, "mIoU": 60.5,
"acc": 91.4, "acc": 91.4,
}, },
default=True,
) )
default = CocoWithVocLabels_V1
class FCN_ResNet101_Weights(WeightsEnum): class FCN_ResNet101_Weights(WeightsEnum):
...@@ -44,8 +44,8 @@ class FCN_ResNet101_Weights(WeightsEnum): ...@@ -44,8 +44,8 @@ class FCN_ResNet101_Weights(WeightsEnum):
"mIoU": 63.7, "mIoU": 63.7,
"acc": 91.9, "acc": 91.9,
}, },
default=True,
) )
default = CocoWithVocLabels_V1
def fcn_resnet50( def fcn_resnet50(
......
...@@ -25,8 +25,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -25,8 +25,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
"mIoU": 57.9, "mIoU": 57.9,
"acc": 91.2, "acc": 91.2,
}, },
default=True,
) )
default = CocoWithVocLabels_V1
def lraspp_mobilenet_v3_large( def lraspp_mobilenet_v3_large(
......
...@@ -57,8 +57,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): ...@@ -57,8 +57,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
"acc@1": 69.362, "acc@1": 69.362,
"acc@5": 88.316, "acc@5": 88.316,
}, },
default=True,
) )
default = ImageNet1K_V1
class ShuffleNet_V2_X1_0_Weights(WeightsEnum): class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
...@@ -70,8 +70,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum): ...@@ -70,8 +70,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
"acc@1": 60.552, "acc@1": 60.552,
"acc@5": 81.746, "acc@5": 81.746,
}, },
default=True,
) )
default = ImageNet1K_V1
class ShuffleNet_V2_X1_5_Weights(WeightsEnum): class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
......
...@@ -30,8 +30,8 @@ class SqueezeNet1_0_Weights(WeightsEnum): ...@@ -30,8 +30,8 @@ class SqueezeNet1_0_Weights(WeightsEnum):
"acc@1": 58.092, "acc@1": 58.092,
"acc@5": 80.420, "acc@5": 80.420,
}, },
default=True,
) )
default = ImageNet1K_V1
class SqueezeNet1_1_Weights(WeightsEnum): class SqueezeNet1_1_Weights(WeightsEnum):
...@@ -43,8 +43,8 @@ class SqueezeNet1_1_Weights(WeightsEnum): ...@@ -43,8 +43,8 @@ class SqueezeNet1_1_Weights(WeightsEnum):
"acc@1": 58.178, "acc@1": 58.178,
"acc@5": 80.624, "acc@5": 80.624,
}, },
default=True,
) )
default = ImageNet1K_V1
def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
......
...@@ -57,8 +57,8 @@ class VGG11_Weights(WeightsEnum): ...@@ -57,8 +57,8 @@ class VGG11_Weights(WeightsEnum):
"acc@1": 69.020, "acc@1": 69.020,
"acc@5": 88.628, "acc@5": 88.628,
}, },
default=True,
) )
default = ImageNet1K_V1
class VGG11_BN_Weights(WeightsEnum): class VGG11_BN_Weights(WeightsEnum):
...@@ -70,8 +70,8 @@ class VGG11_BN_Weights(WeightsEnum): ...@@ -70,8 +70,8 @@ class VGG11_BN_Weights(WeightsEnum):
"acc@1": 70.370, "acc@1": 70.370,
"acc@5": 89.810, "acc@5": 89.810,
}, },
default=True,
) )
default = ImageNet1K_V1
class VGG13_Weights(WeightsEnum): class VGG13_Weights(WeightsEnum):
...@@ -83,8 +83,8 @@ class VGG13_Weights(WeightsEnum): ...@@ -83,8 +83,8 @@ class VGG13_Weights(WeightsEnum):
"acc@1": 69.928, "acc@1": 69.928,
"acc@5": 89.246, "acc@5": 89.246,
}, },
default=True,
) )
default = ImageNet1K_V1
class VGG13_BN_Weights(WeightsEnum): class VGG13_BN_Weights(WeightsEnum):
...@@ -96,8 +96,8 @@ class VGG13_BN_Weights(WeightsEnum): ...@@ -96,8 +96,8 @@ class VGG13_BN_Weights(WeightsEnum):
"acc@1": 71.586, "acc@1": 71.586,
"acc@5": 90.374, "acc@5": 90.374,
}, },
default=True,
) )
default = ImageNet1K_V1
class VGG16_Weights(WeightsEnum): class VGG16_Weights(WeightsEnum):
...@@ -109,7 +109,6 @@ class VGG16_Weights(WeightsEnum): ...@@ -109,7 +109,6 @@ class VGG16_Weights(WeightsEnum):
"acc@1": 71.592, "acc@1": 71.592,
"acc@5": 90.382, "acc@5": 90.382,
}, },
default=True,
) )
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Only the `features` weights have proper values, those on the # same input standardization method as the paper. Only the `features` weights have proper values, those on the
...@@ -127,8 +126,8 @@ class VGG16_Weights(WeightsEnum): ...@@ -127,8 +126,8 @@ class VGG16_Weights(WeightsEnum):
"acc@1": float("nan"), "acc@1": float("nan"),
"acc@5": float("nan"), "acc@5": float("nan"),
}, },
default=False,
) )
default = ImageNet1K_V1
class VGG16_BN_Weights(WeightsEnum): class VGG16_BN_Weights(WeightsEnum):
...@@ -140,8 +139,8 @@ class VGG16_BN_Weights(WeightsEnum): ...@@ -140,8 +139,8 @@ class VGG16_BN_Weights(WeightsEnum):
"acc@1": 73.360, "acc@1": 73.360,
"acc@5": 91.516, "acc@5": 91.516,
}, },
default=True,
) )
default = ImageNet1K_V1
class VGG19_Weights(WeightsEnum): class VGG19_Weights(WeightsEnum):
...@@ -153,8 +152,8 @@ class VGG19_Weights(WeightsEnum): ...@@ -153,8 +152,8 @@ class VGG19_Weights(WeightsEnum):
"acc@1": 72.376, "acc@1": 72.376,
"acc@5": 90.876, "acc@5": 90.876,
}, },
default=True,
) )
default = ImageNet1K_V1
class VGG19_BN_Weights(WeightsEnum): class VGG19_BN_Weights(WeightsEnum):
...@@ -166,8 +165,8 @@ class VGG19_BN_Weights(WeightsEnum): ...@@ -166,8 +165,8 @@ class VGG19_BN_Weights(WeightsEnum):
"acc@1": 74.218, "acc@1": 74.218,
"acc@5": 91.842, "acc@5": 91.842,
}, },
default=True,
) )
default = ImageNet1K_V1
def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......
...@@ -68,8 +68,8 @@ class R3D_18_Weights(WeightsEnum): ...@@ -68,8 +68,8 @@ class R3D_18_Weights(WeightsEnum):
"acc@1": 52.75, "acc@1": 52.75,
"acc@5": 75.45, "acc@5": 75.45,
}, },
default=True,
) )
default = Kinetics400_V1
class MC3_18_Weights(WeightsEnum): class MC3_18_Weights(WeightsEnum):
...@@ -81,8 +81,8 @@ class MC3_18_Weights(WeightsEnum): ...@@ -81,8 +81,8 @@ class MC3_18_Weights(WeightsEnum):
"acc@1": 53.90, "acc@1": 53.90,
"acc@5": 76.29, "acc@5": 76.29,
}, },
default=True,
) )
default = Kinetics400_V1
class R2Plus1D_18_Weights(WeightsEnum): class R2Plus1D_18_Weights(WeightsEnum):
...@@ -94,8 +94,8 @@ class R2Plus1D_18_Weights(WeightsEnum): ...@@ -94,8 +94,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
"acc@1": 57.50, "acc@1": 57.50,
"acc@5": 78.81, "acc@5": 78.81,
}, },
default=True,
) )
default = Kinetics400_V1
def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment