Unverified Commit 68f511eb authored by Yiwen Song, Committed by GitHub

[ViT] Graduate ViT from prototype (#5173)

* graduate vit from prototype

* nit

* add vit to docs and hubconf

* ufmt

* re-correct ufmt

* again

* fix linter
parent d675c0c6
...@@ -40,6 +40,7 @@ architectures for image classification:
- `MNASNet`_
- `EfficientNet`_
- `RegNet`_
- `VisionTransformer`_
You can construct a model with random weights by calling its constructor:
...@@ -82,6 +83,10 @@ You can construct a model with random weights by calling its constructor:
regnet_x_8gf = models.regnet_x_8gf()
regnet_x_16gf = models.regnet_x_16gf()
regnet_x_32gf = models.regnet_x_32gf()
vit_b_16 = models.vit_b_16()
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
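As a minimal sanity-check sketch (the 224x224 input matches the constructors' default ``image_size``; the tensor shapes below are illustrative and not part of this change), a freshly constructed ViT maps a batch of images to one logit per ImageNet class:

import torch
from torchvision import models

vit_b_16 = models.vit_b_16()           # random weights
images = torch.randn(2, 3, 224, 224)   # batch of two 224x224 RGB images
logits = vit_b_16(images)              # shape: (2, 1000)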
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
...@@ -125,6 +130,10 @@ These can be constructed by passing ``pretrained=True``:
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
vit_b_16 = models.vit_b_16(pretrained=True)
vit_b_32 = models.vit_b_32(pretrained=True)
vit_l_16 = models.vit_l_16(pretrained=True)
vit_l_32 = models.vit_l_32(pretrained=True)
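A hedged inference sketch with the pre-trained weights (the preprocessing below uses the standard ImageNet mean/std and a dummy PIL image as a stand-in for a real photo; none of it comes from this diff):

import torch
from PIL import Image
from torchvision import models, transforms

model = models.vit_b_16(pretrained=True).eval()
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.new("RGB", (500, 375))  # stand-in for a real photo
with torch.no_grad():
    logits = model(preprocess(img).unsqueeze(0))
    predicted_class = logits.argmax(dim=1).item()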
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_HOME` environment variable. See
...@@ -233,6 +242,10 @@ regnet_y_3_2gf 78.948 94.576
regnet_y_8gf 80.032 95.048
regnet_y_16gf 80.424 95.240
regnet_y_32gf 80.878 95.340
vit_b_16 81.072 95.318
vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
================================ ============= =============
...@@ -250,6 +263,7 @@ regnet_y_32gf 80.878 95.340
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. _EfficientNet: https://arxiv.org/abs/1905.11946
.. _RegNet: https://arxiv.org/abs/2003.13678
.. _VisionTransformer: https://arxiv.org/abs/2010.11929
.. currentmodule:: torchvision.models
...@@ -433,6 +447,18 @@ RegNet
regnet_x_16gf
regnet_x_32gf
VisionTransformer
-----------------
.. autosummary::
:toctree: generated/
:template: function.rst
vit_b_16
vit_b_32
vit_l_16
vit_l_32
Quantized Models
----------------
...
# Optional list of dependencies required by the package
dependencies = ["torch"]
# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.efficientnet import (
...@@ -47,8 +46,6 @@ from torchvision.models.resnet import (
wide_resnet50_2,
wide_resnet101_2,
)
# segmentation
from torchvision.models.segmentation import (
fcn_resnet50,
fcn_resnet101,
...@@ -60,3 +57,9 @@ from torchvision.models.segmentation import (
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.vision_transformer import (
vit_b_16,
vit_b_32,
vit_l_16,
vit_l_32,
)
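With these hubconf entries, the new models should also be reachable through torch.hub; a sketch (the repo string and the choice of tag are assumptions, not part of this diff):

import torch

# "pytorch/vision" resolves to the default branch; in practice pin a release
# tag that already contains these entrypoints.
model = torch.hub.load("pytorch/vision", "vit_b_16", pretrained=True)
model.eval()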
...@@ -10,6 +10,7 @@ from .mnasnet import *
from .shufflenetv2 import *
from .efficientnet import *
from .regnet import *
from .vision_transformer import *
from . import detection
from . import feature_extraction
from . import optical_flow
...
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once
__all__ = [
"VisionTransformer",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
]
model_urls = {
"vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
"vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
"vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
"vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}
class MLPBlock(nn.Sequential):
"""Transformer MLP block."""
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)
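# Because MLPBlock subclasses nn.Sequential, its forward pass is simply
# linear_1 -> GELU -> dropout_1 -> linear_2 -> dropout_2, and it preserves the
# token layout. Illustrative sketch (shapes chosen to match vit_b_16):
#   mlp = MLPBlock(in_dim=768, mlp_dim=3072, dropout=0.0)
#   y = mlp(torch.randn(2, 197, 768))  # y.shape == (2, 197, 768)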
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
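# In other words, this is the pre-norm residual layout of the ViT paper:
# x = x + SelfAttention(LN(x)), followed by x = x + MLP(LN(x)).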
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
# Note that batch_size is on the first dim because
# the EncoderBlock's nn.MultiheadAttention is constructed with batch_first=True
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))
class VisionTransformer(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""
def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
_log_api_usage_once(self)
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer
input_channels = 3
# The conv_proj is a more efficient version of reshaping, permuting
# and projecting the input
self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
seq_length = (image_size // patch_size) ** 2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
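# Worked example (using the vit_b_16 defaults defined below): image_size=224
# and patch_size=16 give (224 // 16) ** 2 = 196 patch tokens; adding the class
# token yields seq_length = 197.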
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.seq_length = seq_length
heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
self._init_weights()
def _init_weights(self):
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.conv_proj.bias)
if hasattr(self.heads, "pre_logits"):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)
nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, "Wrong image height!")
torch._assert(w == self.image_size, "Wrong image width!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x
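# Shape trace for vit_b_16 at its default 224x224 input (illustrative):
# (n, 3, 224, 224) -conv_proj-> (n, 768, 14, 14) -reshape-> (n, 768, 196)
# -permute-> (n, 196, 768)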
def forward(self, x: torch.Tensor):
# Reshape and permute the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
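# The returned tensor has shape (n, num_classes): only the encoder output at
# the class-token position is fed through self.heads to produce the logits.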
def _vision_transformer(
arch: str,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
pretrained: bool,
progress: bool,
**kwargs: Any,
) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224)
model = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
**kwargs,
)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
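# Note: image_size can be overridden via kwargs, e.g. a hypothetical call
# _vision_transformer(..., image_size=384, ...) builds a 384x384 model; a
# checkpoint trained at 224 would then need its position embeddings adapted
# with interpolate_embeddings (below) before load_state_dict succeeds.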
def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vision_transformer(
arch="vit_b_16",
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
pretrained=pretrained,
progress=progress,
**kwargs,
)
def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vision_transformer(
arch="vit_b_32",
patch_size=32,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
pretrained=pretrained,
progress=progress,
**kwargs,
)
def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vision_transformer(
arch="vit_l_16",
patch_size=16,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
pretrained=pretrained,
progress=progress,
**kwargs,
)
def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _vision_transformer(
arch="vit_l_32",
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
pretrained=pretrained,
progress=progress,
**kwargs,
)
def interpolate_embeddings(
image_size: int,
patch_size: int,
model_state: "OrderedDict[str, torch.Tensor]",
interpolation_mode: str = "bicubic",
reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
"""This function helps interpolating positional embeddings during checkpoint loading,
especially when you want to apply a pre-trained model on images with different resolution.
Args:
image_size (int): Image size of the new model.
patch_size (int): Patch size of the new model.
model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
reset_heads (bool): If True, the state of the classification heads is not copied. Default: False.
Returns:
OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
"""
# Shape of pos_embedding is (1, seq_length, hidden_dim)
pos_embedding = model_state["encoder.pos_embedding"]
n, seq_length, hidden_dim = pos_embedding.shape
if n != 1:
raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
new_seq_length = (image_size // patch_size) ** 2 + 1
# Need to interpolate the weights for the position embedding.
# We do this by reshaping the positions embeddings to a 2d grid, performing
# an interpolation in the (h, w) space and then reshaping back to a 1d grid.
if new_seq_length != seq_length:
# The class token embedding shouldn't be interpolated so we split it up.
seq_length -= 1
new_seq_length -= 1
pos_embedding_token = pos_embedding[:, :1, :]
pos_embedding_img = pos_embedding[:, 1:, :]
# (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
seq_length_1d = int(math.sqrt(seq_length))
torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
new_seq_length_1d = image_size // patch_size
# Perform interpolation.
# (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
new_pos_embedding_img = nn.functional.interpolate(
pos_embedding_img,
size=new_seq_length_1d,
mode=interpolation_mode,
align_corners=True,
)
# (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
# (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
model_state["encoder.pos_embedding"] = new_pos_embedding
if reset_heads:
model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
for k, v in model_state.items():
if not k.startswith("heads"):
model_state_copy[k] = v
model_state = model_state_copy
return model_state
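A hedged end-to-end sketch (not part of this diff; it assumes torch.hub.load_state_dict_from_url for the download and the vit_b_16 hyper-parameters above) showing how interpolate_embeddings can adapt the 224x224 checkpoint for fine-tuning at 384x384:

import torch
from torchvision.models.vision_transformer import VisionTransformer, interpolate_embeddings, model_urls

# Download the 224x224 checkpoint and resize its position embeddings to 384x384.
state_dict = torch.hub.load_state_dict_from_url(model_urls["vit_b_16"], progress=True)
state_dict = interpolate_embeddings(image_size=384, patch_size=16, model_state=state_dict)
# Build a vit_b_16-sized model at the new resolution and load the adapted weights.
model = VisionTransformer(
    image_size=384, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072
)
model.load_state_dict(state_dict)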
...@@ -2,18 +2,13 @@
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Optional
from typing import Any, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...utils import _log_api_usage_once
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings  # noqa: F401
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface
...@@ -31,217 +26,6 @@ __all__ = [
]
class MLPBlock(nn.Sequential):
"""Transformer MLP block."""
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
# Note that batch_size is on the first dim because
# we have batch_first=True in nn.MultiAttention() by default
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))
class VisionTransformer(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""
def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
_log_api_usage_once(self)
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer
input_channels = 3
# The conv_proj is a more efficient version of reshaping, permuting
# and projecting the input
self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
seq_length = (image_size // patch_size) ** 2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.seq_length = seq_length
heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
self._init_weights()
def _init_weights(self):
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.conv_proj.bias)
if hasattr(self.heads, "pre_logits"):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)
nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, "Wrong image height!")
torch._assert(w == self.image_size, "Wrong image width!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x
def forward(self, x: torch.Tensor):
# Reshaping and permuting the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
_COMMON_META = {
"task": "image_classification",
"architecture": "ViT",
...@@ -345,15 +129,6 @@ def _vision_transformer(
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.ImageNet1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer(
...@@ -370,15 +145,6 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.ImageNet1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer(
...@@ -395,15 +161,6 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.ImageNet1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer(
...@@ -420,15 +177,6 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.ImageNet1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (ViT_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer(
...@@ -441,78 +189,3 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
progress=progress,
**kwargs,
)
def interpolate_embeddings(
image_size: int,
patch_size: int,
model_state: "OrderedDict[str, torch.Tensor]",
interpolation_mode: str = "bicubic",
reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
"""This function helps interpolating positional embeddings during checkpoint loading,
especially when you want to apply a pre-trained model on images with different resolution.
Args:
image_size (int): Image size of the new model.
patch_size (int): Patch size of the new model.
model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
reset_heads (bool): If true, not copying the state of heads. Default: False.
Returns:
OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
"""
# Shape of pos_embedding is (1, seq_length, hidden_dim)
pos_embedding = model_state["encoder.pos_embedding"]
n, seq_length, hidden_dim = pos_embedding.shape
if n != 1:
raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
new_seq_length = (image_size // patch_size) ** 2 + 1
# Need to interpolate the weights for the position embedding.
# We do this by reshaping the positions embeddings to a 2d grid, performing
# an interpolation in the (h, w) space and then reshaping back to a 1d grid.
if new_seq_length != seq_length:
# The class token embedding shouldn't be interpolated so we split it up.
seq_length -= 1
new_seq_length -= 1
pos_embedding_token = pos_embedding[:, :1, :]
pos_embedding_img = pos_embedding[:, 1:, :]
# (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
seq_length_1d = int(math.sqrt(seq_length))
torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
new_seq_length_1d = image_size // patch_size
# Perform interpolation.
# (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
new_pos_embedding_img = nn.functional.interpolate(
pos_embedding_img,
size=new_seq_length_1d,
mode=interpolation_mode,
align_corners=True,
)
# (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
# (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
model_state["encoder.pos_embedding"] = new_pos_embedding
if reset_heads:
model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
for k, v in model_state.items():
if not k.startswith("heads"):
model_state_copy[k] = v
model_state = model_state_copy
return model_state