Unverified Commit 3fa24148 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Add SWAG model weight that only the linear head is finetuned to ImageNet1K (#5793)

* Add SWAG model that only the linear classifier head is finetuned with frozen trunk weight

* Add accuracy from experiments

* Change name from SWAG_LC to SWAG_LINEAR

* Add comment on SWAG_LINEAR weight

* Remove the comment docs (moved to PR description), and add the PR url as recipe. Also change name of previous swag model to SWAG_E2E_V1
parent c399c3f4
......@@ -575,7 +575,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 96.328,
},
)
IMAGENET1K_SWAG_V1 = Weights(
IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
transforms=partial(
ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
......@@ -587,6 +587,19 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 98.054,
},
)
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 83590140,
"acc@1": 83.976,
"acc@5": 97.244,
},
)
DEFAULT = IMAGENET1K_V2
......@@ -613,7 +626,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 96.498,
},
)
IMAGENET1K_SWAG_V1 = Weights(
IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
transforms=partial(
ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
......@@ -625,11 +638,24 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 98.362,
},
)
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 145046770,
"acc@1": 84.622,
"acc@5": 97.480,
},
)
DEFAULT = IMAGENET1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum):
IMAGENET1K_SWAG_V1 = Weights(
IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
transforms=partial(
ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
......@@ -641,7 +667,20 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
"acc@5": 98.682,
},
)
DEFAULT = IMAGENET1K_SWAG_V1
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 644812894,
"acc@1": 86.068,
"acc@5": 97.844,
},
)
DEFAULT = IMAGENET1K_SWAG_E2E_V1
class RegNet_X_400MF_Weights(WeightsEnum):
......
......@@ -349,7 +349,7 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 95.318,
},
)
IMAGENET1K_SWAG_V1 = Weights(
IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
transforms=partial(
ImageClassification,
......@@ -366,6 +366,24 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 97.650,
},
)
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
transforms=partial(
ImageClassification,
crop_size=224,
resize_size=224,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 86567656,
"size": (224, 224),
"min_size": (224, 224),
"acc@1": 81.886,
"acc@5": 96.180,
},
)
DEFAULT = IMAGENET1K_V1
......@@ -400,7 +418,7 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 94.638,
},
)
IMAGENET1K_SWAG_V1 = Weights(
IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
transforms=partial(
ImageClassification,
......@@ -417,6 +435,24 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 98.512,
},
)
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
transforms=partial(
ImageClassification,
crop_size=224,
resize_size=224,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 304326632,
"size": (224, 224),
"min_size": (224, 224),
"acc@1": 85.146,
"acc@5": 97.422,
},
)
DEFAULT = IMAGENET1K_V1
......@@ -438,7 +474,7 @@ class ViT_L_32_Weights(WeightsEnum):
class ViT_H_14_Weights(WeightsEnum):
IMAGENET1K_SWAG_V1 = Weights(
IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
transforms=partial(
ImageClassification,
......@@ -455,7 +491,25 @@ class ViT_H_14_Weights(WeightsEnum):
"acc@5": 98.694,
},
)
DEFAULT = IMAGENET1K_SWAG_V1
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
transforms=partial(
ImageClassification,
crop_size=224,
resize_size=224,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 632045800,
"size": (224, 224),
"min_size": (224, 224),
"acc@1": 85.708,
"acc@5": 97.730,
},
)
DEFAULT = IMAGENET1K_SWAG_E2E_V1
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment