"vscode:/vscode.git/clone" did not exist on "38f89e595b56c0bbea6e993c3c3705ca502bf884"
Unverified commit 781b0f9c authored by YosuaMichael, committed by GitHub

Add SWAG Vision Transformer Weight (#5714)



* Add vit_b_16_swag

* Better handling idiom for image_size; edit test_extended_model to handle the case where the number of params differs from the default due to a different image size input

* Update the accuracy to the experiment result on the torchvision model

* Fix typo: missing underscore

* Raise an exception instead of torch._assert; add back the publication year (accidentally deleted)

* Add license information on meta and readme

* Improve wording and fix typo for pretrained model license in readme

* Add vit_l_16 weight

* Update README.rst
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update the accuracy meta on the vit_l_16_swag model to the result from our experiment
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 3925946f
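
Usage note (not part of the commit): once these weights land, the ViT-B/16 SWAG checkpoint is selected through the multi-weight API. A minimal sketch, assuming the enums are importable from torchvision.models as in later releases (at the time of this PR they may still have lived under the prototype namespace):

from torchvision.models import vit_b_16, ViT_B_16_Weights

# Builds ViT-B/16 and downloads the SWAG checkpoint added by this commit;
# the builder reads image_size=384 from the weights' meta (see the diff below).
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_V1)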
@@ -185,3 +185,10 @@ Disclaimer on Datasets
 This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
 
 If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
+
+Pre-trained Model License
+=========================
+
+The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
+
+More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.
@@ -115,7 +115,8 @@ def test_schema_meta_validation(model_fn):
                 incorrect_params.append(w)
         else:
             if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                incorrect_params.append(w)
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
         if not w.name.isupper():
             bad_names.append(w)
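
Reviewer note on why this fallback is needed (illustrative, not part of the diff): the SWAG weights are the first entries whose num_params differs from the DEFAULT entry of the same enum, because the larger input resolution enlarges the learned position embedding. Assuming the default ViT-B/16 entry reports 86,567,656 parameters (per torchvision's model table), the arithmetic checks out against the new meta values:

# ViT-B/16 (hidden_dim=768): 384x384 input gives (384/16)**2 + 1 = 577 tokens
# versus (224/16)**2 + 1 = 197 at the default resolution.
extra_tokens = (384 // 16) ** 2 - (224 // 16) ** 2   # 380 extra patch tokens
extra_params = extra_tokens * 768                    # 291,840 extra position-embedding params
assert 86_567_656 + extra_params == 86_859_496       # matches the SWAG meta["num_params"]
# The same reasoning applies to ViT-L/16 at 512x512 with hidden_dim=1024.

The remaining hunks below are from the Vision Transformer model builder file.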
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, NamedTuple, Optional
+from typing import Any, Callable, List, NamedTuple, Optional, Sequence
 
 import torch
 import torch.nn as nn
@@ -284,10 +284,21 @@ def _vision_transformer(
     progress: bool,
     **kwargs: Any,
 ) -> VisionTransformer:
-    image_size = kwargs.pop("image_size", 224)
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if isinstance(weights.meta["size"], int):
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
+        elif isinstance(weights.meta["size"], Sequence):
+            if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
+                raise ValueError(
+                    f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
+                )
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
+        else:
+            raise ValueError(
+                f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
+            )
+    image_size = kwargs.pop("image_size", 224)
 
     model = VisionTransformer(
         image_size=image_size,
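
In effect, image_size is now resolved as follows (a hedged reading of the builder, not part of the diff): with weights supplied, the value comes from weights.meta["size"]; without weights, the user's kwarg or the 224 default applies; a kwarg that conflicts with the weights' meta is expected to raise inside _ovewrite_named_param. A small illustration, assuming the public builders:

from torchvision.models import vit_b_16, ViT_B_16_Weights

vit_b_16()                # no weights, no kwarg: image_size defaults to 224
vit_b_16(image_size=512)  # no weights: the user-supplied kwarg is honored
# With SWAG weights, image_size is taken from meta ("size" = (384, 384));
# passing a conflicting value should raise a ValueError:
# vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_V1, image_size=224)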
@@ -313,6 +324,14 @@ _COMMON_META = {
     "interpolation": InterpolationMode.BILINEAR,
 }
 
+_COMMON_SWAG_META = {
+    **_COMMON_META,
+    "publication_year": 2022,
+    "recipe": "https://github.com/facebookresearch/SWAG",
+    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
+    "interpolation": InterpolationMode.BICUBIC,
+}
+
 
 class ViT_B_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
@@ -328,6 +347,23 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=384,
+            resize_size=384,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 86859496,
+            "size": (384, 384),
+            "min_size": (384, 384),
+            "acc@1": 85.304,
+            "acc@5": 97.650,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
@@ -362,6 +398,23 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=512,
+            resize_size=512,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 305174504,
+            "size": (512, 512),
+            "min_size": (512, 512),
+            "acc@1": 88.064,
+            "acc@5": 98.512,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
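
Usage note (not part of the diff): each weight entry bundles its preprocessing, so callers do not have to hard-code the 512x512 bicubic resize/crop above. A hedged end-to-end sketch, assuming the multi-weight API of later torchvision releases and a dummy tensor in place of a real image:

import torch
from torchvision.models import vit_l_16, ViT_L_16_Weights

weights = ViT_L_16_Weights.IMAGENET1K_SWAG_V1
model = vit_l_16(weights=weights).eval()
preprocess = weights.transforms()                 # resize 512, center-crop 512, bicubic, normalize
batch = preprocess(torch.rand(3, 800, 600)).unsqueeze(0)  # stand-in for a decoded image tensor
with torch.inference_mode():
    logits = model(batch)
print(logits.shape)                               # torch.Size([1, 1000])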