Unverified Commit 3fa24148 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Add SWAG model weights in which only the linear head is finetuned on ImageNet1K (#5793)

* Add SWAG model weights in which only the linear classifier head is finetuned, with frozen trunk weights

* Add accuracy from experiments

* Change name from SWAG_LC to SWAG_LINEAR

* Add comment on SWAG_LINEAR weight

* Remove the comment docs (moved to PR description), and add the PR URL as recipe. Also change the name of the previous SWAG model weights to SWAG_E2E_V1
parent c399c3f4
...@@ -575,7 +575,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -575,7 +575,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 96.328, "acc@5": 96.328,
}, },
) )
IMAGENET1K_SWAG_V1 = Weights( IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth", url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
...@@ -587,6 +587,19 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -587,6 +587,19 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 98.054, "acc@5": 98.054,
}, },
) )
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 83590140,
"acc@1": 83.976,
"acc@5": 97.244,
},
)
DEFAULT = IMAGENET1K_V2 DEFAULT = IMAGENET1K_V2
...@@ -613,7 +626,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -613,7 +626,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 96.498, "acc@5": 96.498,
}, },
) )
IMAGENET1K_SWAG_V1 = Weights( IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth", url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
...@@ -625,11 +638,24 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -625,11 +638,24 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 98.362, "acc@5": 98.362,
}, },
) )
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 145046770,
"acc@1": 84.622,
"acc@5": 97.480,
},
)
DEFAULT = IMAGENET1K_V2 DEFAULT = IMAGENET1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum): class RegNet_Y_128GF_Weights(WeightsEnum):
IMAGENET1K_SWAG_V1 = Weights( IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth", url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
...@@ -641,7 +667,20 @@ class RegNet_Y_128GF_Weights(WeightsEnum): ...@@ -641,7 +667,20 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
"acc@5": 98.682, "acc@5": 98.682,
}, },
) )
DEFAULT = IMAGENET1K_SWAG_V1 IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 644812894,
"acc@1": 86.068,
"acc@5": 97.844,
},
)
DEFAULT = IMAGENET1K_SWAG_E2E_V1
class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum):
......
...@@ -349,7 +349,7 @@ class ViT_B_16_Weights(WeightsEnum): ...@@ -349,7 +349,7 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 95.318, "acc@5": 95.318,
}, },
) )
IMAGENET1K_SWAG_V1 = Weights( IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth", url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
transforms=partial( transforms=partial(
ImageClassification, ImageClassification,
...@@ -366,6 +366,24 @@ class ViT_B_16_Weights(WeightsEnum): ...@@ -366,6 +366,24 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 97.650, "acc@5": 97.650,
}, },
) )
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
transforms=partial(
ImageClassification,
crop_size=224,
resize_size=224,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 86567656,
"size": (224, 224),
"min_size": (224, 224),
"acc@1": 81.886,
"acc@5": 96.180,
},
)
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -400,7 +418,7 @@ class ViT_L_16_Weights(WeightsEnum): ...@@ -400,7 +418,7 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 94.638, "acc@5": 94.638,
}, },
) )
IMAGENET1K_SWAG_V1 = Weights( IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth", url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
transforms=partial( transforms=partial(
ImageClassification, ImageClassification,
...@@ -417,6 +435,24 @@ class ViT_L_16_Weights(WeightsEnum): ...@@ -417,6 +435,24 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 98.512, "acc@5": 98.512,
}, },
) )
IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
transforms=partial(
ImageClassification,
crop_size=224,
resize_size=224,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 304326632,
"size": (224, 224),
"min_size": (224, 224),
"acc@1": 85.146,
"acc@5": 97.422,
},
)
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -438,7 +474,7 @@ class ViT_L_32_Weights(WeightsEnum): ...@@ -438,7 +474,7 @@ class ViT_L_32_Weights(WeightsEnum):
class ViT_H_14_Weights(WeightsEnum): class ViT_H_14_Weights(WeightsEnum):
IMAGENET1K_SWAG_V1 = Weights( IMAGENET1K_SWAG_E2E_V1 = Weights(
url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth", url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
transforms=partial( transforms=partial(
ImageClassification, ImageClassification,
...@@ -455,7 +491,25 @@ class ViT_H_14_Weights(WeightsEnum): ...@@ -455,7 +491,25 @@ class ViT_H_14_Weights(WeightsEnum):
"acc@5": 98.694, "acc@5": 98.694,
}, },
) )
DEFAULT = IMAGENET1K_SWAG_V1 IMAGENET1K_SWAG_LINEAR_V1 = Weights(
url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
transforms=partial(
ImageClassification,
crop_size=224,
resize_size=224,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"recipe": "https://github.com/pytorch/vision/pull/5793",
"num_params": 632045800,
"size": (224, 224),
"min_size": (224, 224),
"acc@1": 85.708,
"acc@5": 97.730,
},
)
DEFAULT = IMAGENET1K_SWAG_E2E_V1
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment