Unverified Commit 63576c9f authored by YosuaMichael, committed by GitHub

Adding the huge vision transformer from SWAG (#5721)



* Add vit_b_16_swag

* Better handling idiom for image_size; edit test_extended_model to handle the case where the number of parameters differs from the default due to a different image size input

* Update the accuracy to the experiment results on the torchvision model

* Fix typo: missing underscore

* Raise an exception instead of torch._assert; add back publication year (accidentally deleted)

* Add license information to meta and README

* Improve wording and fix a typo for the pretrained model license in the README

* Add vit_l_16 weight

* Update README.rst
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update the accuracy meta on the vit_l_16_swag model to the result from our experiment

* Add vit_h_14_swag model

* Add accuracy from experiments

* Add the vit_h_14 model to hubconf.py

* Add docs and expected pkl file for test

* Remove legacy compatibility for ViT_H_14 model
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Test vit_h_14 with a smaller image_size to speed up the test
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent d0c92dc9
@@ -92,6 +92,7 @@ You can construct a model with random weights by calling its constructor:
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
vit_h_14 = models.vit_h_14()
convnext_tiny = models.convnext_tiny()
convnext_small = models.convnext_small()
convnext_base = models.convnext_base()
@@ -213,6 +214,7 @@ vit_b_16 81.072 95.318
vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
vit_h_14 88.552 98.694
convnext_tiny 82.520 96.146
convnext_small 83.616 96.650
convnext_base 84.062 96.870
@@ -434,6 +436,7 @@ VisionTransformer
vit_b_32
vit_l_16
vit_l_32
vit_h_14
ConvNeXt
--------
......
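For quick reference alongside the documentation changes above, a minimal sketch of building the newly listed model with random weights (the 224x224 input and 1000-class output are assumptions based on the builder's defaults, not shown in this diff):

import torch
from torchvision import models

# Randomly initialized ViT-H/14, mirroring the constructor list documented above.
# Assuming the default 224x224 image size, patch_size=14 yields a 16x16 patch grid.
vit_h_14 = models.vit_h_14()
out = vit_h_14(torch.rand(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])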
@@ -67,4 +67,5 @@ from torchvision.models.vision_transformer import (
vit_b_32,
vit_l_16,
vit_l_32,
vit_h_14,
)
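Because hubconf.py now re-exports vit_h_14, the entrypoint should also be reachable through torch.hub; a hedged sketch, relying on torch.hub.load forwarding extra keyword arguments to the entrypoint:

import torch
from torchvision.models import ViT_H_14_Weights  # weights enum added in this change

# Load the hub entrypoint exposed above; the weights kwarg is passed through to vit_h_14().
model = torch.hub.load("pytorch/vision", "vit_h_14", weights=ViT_H_14_Weights.IMAGENET1K_SWAG_V1)
model.eval()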
@@ -280,6 +280,10 @@ _model_params = {
"rpn_pre_nms_top_n_test": 1000,
"rpn_post_nms_top_n_test": 1000,
},
"vit_h_14": {
"image_size": 56,
"input_shape": (1, 3, 56, 56),
},
}
# speeding up slow models:
slow_models = [
......
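The 56x56 override above keeps the huge model's forward pass cheap during testing; a minimal sketch of the equivalent standalone call (passing image_size through the builder's kwargs is an assumption based on how _vision_transformer consumes it):

import torch
from torchvision import models

# A 56x56 input with patch_size=14 gives only a 4x4 patch grid, so the forward
# pass is far cheaper than at the 518x518 SWAG resolution. The positional
# embedding (and hence the parameter count) shrinks with the image size, which
# is what the test_extended_model change mentioned above accounts for.
model = models.vit_h_14(image_size=56)
out = model(torch.rand(1, 3, 56, 56))
print(out.shape)  # torch.Size([1, 1000])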
@@ -20,10 +20,12 @@ __all__ = [
"ViT_B_32_Weights",
"ViT_L_16_Weights",
"ViT_L_32_Weights",
"ViT_H_14_Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
"vit_h_14",
]
@@ -435,6 +437,27 @@ class ViT_L_32_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
class ViT_H_14_Weights(WeightsEnum):
IMAGENET1K_SWAG_V1 = Weights(
url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
transforms=partial(
ImageClassification,
crop_size=518,
resize_size=518,
interpolation=InterpolationMode.BICUBIC,
),
meta={
**_COMMON_SWAG_META,
"num_params": 633470440,
"size": (518, 518),
"min_size": (518, 518),
"acc@1": 88.552,
"acc@5": 98.694,
},
)
DEFAULT = IMAGENET1K_SWAG_V1
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
@@ -531,6 +554,29 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
)
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_h_14 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_H_14_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
"""
weights = ViT_H_14_Weights.verify(weights)
return _vision_transformer(
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
weights=weights,
progress=progress,
**kwargs,
)
def interpolate_embeddings(
image_size: int,
patch_size: int,
......
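Putting the new weights enum and builder together, a hedged end-to-end inference sketch (the weights.transforms() call and the example image path are assumptions not shown in this diff):

import torch
from PIL import Image
from torchvision.models import vit_h_14, ViT_H_14_Weights

weights = ViT_H_14_Weights.IMAGENET1K_SWAG_V1
model = vit_h_14(weights=weights)
model.eval()

# The weights are assumed to carry the 518x518 bicubic preprocessing defined above.
preprocess = weights.transforms()
img = Image.open("example.jpg")  # hypothetical local image
batch = preprocess(img).unsqueeze(0)

with torch.inference_mode():
    probs = model(batch).softmax(dim=-1)
print(probs.argmax(dim=-1))  # predicted ImageNet-1k class index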