"src/vscode:/vscode.git/clone" did not exist on "d70f8ee18b50c38f377a18a9fa8da0ae15b6426d"
Unverified commit 03d11338, authored by Yiwen Song, committed by GitHub

Adding vit_h_14 architecture (#5210)

* adding vit_h_14

* prototype and docs

* bug fix

* adding curl check
parent abc6c778
@@ -88,6 +88,7 @@ You can construct a model with random weights by calling its constructor:
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
vit_h_14 = models.vit_h_14()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
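A minimal usage sketch of the constructor API described above (assuming the ``torchvision.models`` namespace; note that ``vit_h_14`` has no pretrained checkpoint, so only random initialization is valid for it):

import torchvision.models as models

# Random initialization works for every ViT variant, including the new vit_h_14.
vit_h_14 = models.vit_h_14()

# Pretrained weights (fetched via torch.utils.model_zoo) exist for the smaller variants.
vit_b_16 = models.vit_b_16(pretrained=True)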
@@ -460,6 +461,7 @@ VisionTransformer
vit_b_32
vit_l_16
vit_l_32
vit_h_14
Quantized Models
----------------
@@ -63,4 +63,5 @@ from torchvision.models.vision_transformer import (
vit_b_32,
vit_l_16,
vit_l_32,
vit_h_14,
)
@@ -15,6 +15,7 @@ __all__ = [
"vit_b_32",
"vit_l_16",
"vit_l_32",
"vit_h_14",
]
model_urls = {
@@ -260,6 +261,8 @@ def _vision_transformer(
)
if pretrained:
if arch not in model_urls:
raise ValueError(f"No checkpoint is available for model type '{arch}'!")
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
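For illustration, a hedged sketch of how this check behaves for the new architecture, assuming ``vit_h_14`` has no entry in ``model_urls`` (as its docstring below states):

from torchvision.models import vit_h_14

model = vit_h_14()  # random weights: always allowed
try:
    model = vit_h_14(pretrained=True)
except ValueError as err:
    print(err)  # "No checkpoint is available for model type 'vit_h_14'!"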
@@ -354,6 +357,26 @@ def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
)
def vit_h_14(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_h_14 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
NOTE: Pretrained weights are not available for this model.
"""
return _vision_transformer(
arch="vit_h_14",
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
pretrained=pretrained,
progress=progress,
**kwargs,
)
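A quick forward-pass sanity check of the new constructor; the 224x224 input resolution and 1000-class head are assumed defaults, not values stated in this diff:

import torch
from torchvision.models import vit_h_14

model = vit_h_14()                  # no pretrained checkpoint is available yet
x = torch.randn(1, 3, 224, 224)     # assumes the default image_size of 224
out = model(x)
print(out.shape)                    # torch.Size([1, 1000]) with the default num_classes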
def interpolate_embeddings(
image_size: int,
patch_size: int,
@@ -19,10 +19,12 @@ __all__ = [
"ViT_B_32_Weights",
"ViT_L_16_Weights",
"ViT_L_32_Weights",
"ViT_H_14_Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
"vit_h_14",
]
@@ -99,6 +101,11 @@ class ViT_L_32_Weights(WeightsEnum):
default = ImageNet1K_V1
class ViT_H_14_Weights(WeightsEnum):
# Weights are not available yet.
pass
def _vision_transformer(
patch_size: int,
num_layers: int,
@@ -192,3 +199,19 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
progress=progress,
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", None))
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights = ViT_H_14_Weights.verify(weights)
return _vision_transformer(
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
weights=weights,
progress=progress,
**kwargs,
)
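Under the prototype weights API this variant can currently only be built without weights, since ``ViT_H_14_Weights`` defines no entries; a minimal sketch, assuming the prototype import path used by this branch:

from torchvision.prototype.models import vit_h_14  # assumed import path

model = vit_h_14(weights=None)  # the only valid choice until ViT_H_14_Weights gains entries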