Unverified Commit 3fa24148 authored by YosuaMichael, committed by GitHub

Add SWAG model weights where only the linear head is finetuned on ImageNet1K (#5793)

* Add SWAG model weights where only the linear classifier head is finetuned, keeping the trunk weights frozen

* Add accuracy from experiments

* Change name from SWAG_LC to SWAG_LINEAR

* Add comment on SWAG_LINEAR weight

* Remove the comment docs (moved to the PR description) and add the PR URL as the recipe. Also rename the previous SWAG weights to SWAG_E2E_V1
parent c399c3f4
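For reference, a minimal usage sketch of the new weights (assuming the multi-weight API as it ships in torchvision 0.13+; the builder name and the weights= argument are the standard API, everything else below is illustrative):

import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Select the new linear-probe SWAG checkpoint added in this PR.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1
model = vit_b_16(weights=weights)
model.eval()

# Each Weights entry bundles its own preprocessing (here 224x224, bicubic).
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 480, 640)).unsqueeze(0)  # stand-in for a real image

with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])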
@@ -575,7 +575,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
             "acc@5": 96.328,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
         transforms=partial(
             ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -587,6 +587,19 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
             "acc@5": 98.054,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 83590140,
+            "acc@1": 83.976,
+            "acc@5": 97.244,
+        },
+    )
     DEFAULT = IMAGENET1K_V2
@@ -613,7 +626,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
             "acc@5": 96.498,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
         transforms=partial(
             ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -625,11 +638,24 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
             "acc@5": 98.362,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 145046770,
+            "acc@1": 84.622,
+            "acc@5": 97.480,
+        },
+    )
     DEFAULT = IMAGENET1K_V2


 class RegNet_Y_128GF_Weights(WeightsEnum):
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
         transforms=partial(
             ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
@@ -641,7 +667,20 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
             "acc@5": 98.682,
         },
     )
-    DEFAULT = IMAGENET1K_SWAG_V1
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 644812894,
+            "acc@1": 86.068,
+            "acc@5": 97.844,
+        },
+    )
+    DEFAULT = IMAGENET1K_SWAG_E2E_V1


 class RegNet_X_400MF_Weights(WeightsEnum):
...
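The new *_LINEAR_V1 checkpoints above differ from the *_E2E_V1 ones in two ways visible in the diff: only the linear classification head was trained on ImageNet1K while the SWAG trunk stayed frozen, and inference uses 224x224 bicubic preprocessing instead of 384x384. A rough sketch of such a linear-probe setup follows; it is illustrative only, and the optimizer settings and hyper-parameters are assumptions, not the recipe from the PR description:

import torch
import torch.nn as nn
from torchvision.models import regnet_y_16gf, RegNet_Y_16GF_Weights

# Start from the end-to-end SWAG trunk and freeze it.
model = regnet_y_16gf(weights=RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1)
for p in model.parameters():
    p.requires_grad = False

# Replace the classification head with a fresh linear layer; only this part trains.
model.fc = nn.Linear(model.fc.in_features, 1000)

# Hand only the trainable (head) parameters to the optimizer.
head_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(head_params, lr=0.1, momentum=0.9)  # illustrative hyper-parameters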
@@ -349,7 +349,7 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
         transforms=partial(
             ImageClassification,
@@ -366,6 +366,24 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 97.650,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=224,
+            resize_size=224,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 86567656,
+            "size": (224, 224),
+            "min_size": (224, 224),
+            "acc@1": 81.886,
+            "acc@5": 96.180,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
@@ -400,7 +418,7 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
         transforms=partial(
             ImageClassification,
@@ -417,6 +435,24 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 98.512,
         },
     )
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=224,
+            resize_size=224,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 304326632,
+            "size": (224, 224),
+            "min_size": (224, 224),
+            "acc@1": 85.146,
+            "acc@5": 97.422,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
@@ -438,7 +474,7 @@ class ViT_L_32_Weights(WeightsEnum):


 class ViT_H_14_Weights(WeightsEnum):
-    IMAGENET1K_SWAG_V1 = Weights(
+    IMAGENET1K_SWAG_E2E_V1 = Weights(
         url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
         transforms=partial(
             ImageClassification,
@@ -455,7 +491,25 @@ class ViT_H_14_Weights(WeightsEnum):
             "acc@5": 98.694,
         },
     )
-    DEFAULT = IMAGENET1K_SWAG_V1
+    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=224,
+            resize_size=224,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "recipe": "https://github.com/pytorch/vision/pull/5793",
+            "num_params": 632045800,
+            "size": (224, 224),
+            "min_size": (224, 224),
+            "acc@1": 85.708,
+            "acc@5": 97.730,
+        },
+    )
+    DEFAULT = IMAGENET1K_SWAG_E2E_V1


 @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
...
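Two side effects of the rename are worth noting: code that referenced the old IMAGENET1K_SWAG_V1 member must switch to IMAGENET1K_SWAG_E2E_V1, and for the SWAG-only models (RegNet_Y_128GF, ViT_H_14) DEFAULT keeps pointing at the end-to-end weights. A small check, assuming standard Enum aliasing; reading crop_size off the preset object is an implementation detail used here only for illustration:

from torchvision.models import ViT_H_14_Weights

# DEFAULT is an alias of the end-to-end SWAG member after the rename.
print(ViT_H_14_Weights.DEFAULT is ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1)  # expected: True

# The two SWAG variants ship different inference presets:
# a larger resolution for the end-to-end weights, 224x224 for the linear-probe ones.
e2e_preset = ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1.transforms()
linear_preset = ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1.transforms()
print(e2e_preset.crop_size, linear_preset.crop_size)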