Unverified commit 47281bbf authored by Yiwen Song, committed by GitHub

Adding ViT to torchvision/models (#4594)



* [vit] Adding ViT to torchvision/models

* adding pre-logits layer + resolving comments

* Fix the model attribute bug

* Change version to arch

* fix failing unittests

* remove useless prints

* reduce input size to fix unittests

* Increase windows-cpu executor to 2xlarge

* Use `batch_first=True` and remove classifier

* Change resource_class back to xlarge

* Remove vit_h_14

* Remove vit_h_14 from __all__

* Move vision_transformer.py into prototype

* Fix formatting issue

* remove arch in builder

* Fix type err in model builder

* address comments and trigger unittests

* remove the prototype import in torchvision.models

* Adding vit back to models to trigger CircleCI test

* fix test_jit_forward_backward

* Move all to prototype.

* Adopt new helper methods and fix prototype tests.

* Remove unused import.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
parent 29f38f17
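For reference, a minimal usage sketch of the new builders (a hypothetical example, not part of the diff: it assumes the torchvision.prototype.models namespace this PR moves the model into, and since no pretrained checkpoints are registered yet, weights=None is the only option):

import torch
from torchvision.prototype.models import vit_b_16

# Build the base 16x16-patch ViT; weights=None because this commit
# registers no pretrained checkpoints for any of the ViT variants.
model = vit_b_16(weights=None, num_classes=1000).eval()

# The builder defaults to image_size=224, so a standard ImageNet-sized
# batch goes straight through patchification, the encoder and the head.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)  # shape: (1, 1000)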
import random
from itertools import chain
from typing import Mapping, Sequence
import pytest
import torch
@@ -89,7 +90,16 @@ class TestFxFeatureExtraction:
def _get_return_nodes(self, model):
set_rng_seed(0)
exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
exclude_nodes_filter = [
"getitem",
"floordiv",
"size",
"chunk",
"_assert",
"eq",
"dim",
"getattr",
]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
)
@@ -144,7 +154,16 @@ class TestFxFeatureExtraction:
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
out = model(self.inp)
sum(o.mean() for o in out.values()).backward()
out_agg = 0
for node_out in out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.mean()
out_agg.backward()
def test_feature_extraction_methods_equivalence(self):
model = models.resnet18(**self.model_defaults).eval()
@@ -176,7 +195,16 @@ class TestFxFeatureExtraction:
)
model = torch.jit.script(model)
fgn_out = model(self.inp)
sum(o.mean() for o in fgn_out.values()).backward()
out_agg = 0
for node_out in fgn_out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.mean()
out_agg.backward()
def test_train_eval(self):
class TestModel(torch.nn.Module):
@@ -507,6 +507,7 @@ def test_classification_model(model_fn, dev):
}
model_name = model_fn.__name__
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
input_shape = kwargs.pop("input_shape")
model = model_fn(**kwargs)
@@ -515,7 +516,7 @@ def test_classification_model(model_fn, dev):
x = torch.rand(input_shape).to(device=dev)
out = model(x)
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50
assert out.shape[-1] == num_classes
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
@@ -122,8 +122,11 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
x = [x]
# compare with new model builder parameterized in the old fashion way
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
try:
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
model_new = _build_model(model_fn, **kwargs).to(device=dev)
except ModuleNotFoundError:
pytest.skip(f"Model '{model_name}' not available in both modules.")
torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
@@ -10,6 +10,7 @@ from .resnet import *
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from . import detection
from . import quantization
from . import segmentation
# References:
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
from torch import Tensor
from ._api import Weights
from ._utils import _deprecated_param, _deprecated_positional
__all__ = [
"VisionTransformer",
"VisionTransformer_B_16Weights",
"VisionTransformer_B_32Weights",
"VisionTransformer_L_16Weights",
"VisionTransformer_L_32Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
]
class MLPBlock(nn.Sequential):
"""Transformer MLP block."""
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
# Note that batch_size is on the first dim because
# we use batch_first=True in nn.MultiheadAttention()
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))
class VisionTransformer(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""
def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer
input_channels = 3
# The conv_proj is a more efficient version of reshaping, permuting
# and projecting the input
self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
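# Equivalently (though less efficiently) one could split the image into
# non-overlapping patch_size x patch_size patches, flatten each patch to a
# vector of length input_channels * patch_size ** 2, and apply
# nn.Linear(input_channels * patch_size ** 2, hidden_dim); the strided
# convolution above computes the same per-patch linear projection.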
seq_length = (image_size // patch_size) ** 2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.seq_length = seq_length
heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
self._init_weights()
def _init_weights(self):
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.conv_proj.bias)
if hasattr(self.heads, "pre_logits"):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)
nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
def forward(self, x: torch.Tensor):
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, "Wrong image height!")
torch._assert(w == self.image_size, "Wrong image width!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
# Expand the class token to the full batch.
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
class VisionTransformer_B_16Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_b_16
pass
class VisionTransformer_B_32Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_b_32
pass
class VisionTransformer_L_16Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_l_16
pass
class VisionTransformer_L_32Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_l_32
pass
def _vision_transformer(
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
weights: Optional[Weights],
progress: bool,
**kwargs: Any,
) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224)
model = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
**kwargs,
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
def vit_b_16(
weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_16Weights.verify(weights)
return _vision_transformer(
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights,
progress=progress,
**kwargs,
)
def vit_b_32(
weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_32Weights.verify(weights)
return _vision_transformer(
patch_size=32,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights,
progress=progress,
**kwargs,
)
def vit_l_16(
weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_16Weights.verify(weights)
return _vision_transformer(
patch_size=16,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights,
progress=progress,
**kwargs,
)
def vit_l_32(
weights: Optional[VisionTransformer_L_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_32Weights.verify(weights)
return _vision_transformer(
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights,
progress=progress,
**kwargs,
)