Unverified Commit 47281bbf authored by Yiwen Song, committed by GitHub

Adding ViT to torchvision/models (#4594)



* [vit] Adding ViT to torchvision/models

* adding pre-logits layer + resolving comments

* Fix the model attribute bug

* Change version to arch

* fix failing unittests

* remove useless prints

* reduce input size to fix unittests

* Increase windows-cpu executor to 2xlarge

* Use `batch_first=True` and remove classifier

* Change resource_class back to xlarge

* Remove vit_h_14

* Remove vit_h_14 from __all__

* Move vision_transformer.py into prototype

* Fix formatting issue

* remove arch in builder

* Fix type err in model builder

* address comments and trigger unittests

* remove the prototype import in torchvision.models

* Adding vit back to models to trigger CircleCI test

* fix test_jit_forward_backward

* Move all to prototype.

* Adopt new helper methods and fix prototype tests.

* Remove unused import.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
parent 29f38f17
4 files suppressed by a .gitattributes entry or an unsupported file encoding.
import random
from itertools import chain
from typing import Mapping, Sequence
import pytest
import torch
@@ -89,7 +90,16 @@ class TestFxFeatureExtraction:
    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
@@ -144,7 +154,16 @@ class TestFxFeatureExtraction:
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
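        # Graph nodes may return a plain Tensor, a sequence of Tensors, or a mapping of
        # Tensors (some entries possibly None), so the means of all Tensor outputs are
        # accumulated and a single backward() is run over the aggregate.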
        out_agg = 0
        for node_out in out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.mean()
        out_agg.backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
@@ -176,7 +195,16 @@ class TestFxFeatureExtraction:
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        out_agg = 0
        for node_out in fgn_out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.mean()
        out_agg.backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
...
@@ -507,6 +507,7 @@ def test_classification_model(model_fn, dev):
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    num_classes = kwargs.get("num_classes")
    input_shape = kwargs.pop("input_shape")
    model = model_fn(**kwargs)
@@ -515,7 +516,7 @@ def test_classification_model(model_fn, dev):
    x = torch.rand(input_shape).to(device=dev)
    out = model(x)
    _assert_expected(out.cpu(), model_name, prec=0.1)
    assert out.shape[-1] == num_classes
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
    _check_fx_compatible(model, x)
...
@@ -122,8 +122,11 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
        x = [x]
    # compare with new model builder parameterized in the old fashion way
    try:
        model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
        model_new = _build_model(model_fn, **kwargs).to(device=dev)
    except ModuleNotFoundError:
        pytest.skip(f"Model '{model_name}' not available in both modules.")
    torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
...
@@ -10,6 +10,7 @@ from .resnet import *
from .shufflenetv2 import *
from .squeezenet import *
from .vgg import *
from .vision_transformer import *
from . import detection
from . import quantization
from . import segmentation
...
# References:
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
from torch import Tensor
from ._api import Weights
from ._utils import _deprecated_param, _deprecated_positional
__all__ = [
"VisionTransformer",
"VisionTransformer_B_16Weights",
"VisionTransformer_B_32Weights",
"VisionTransformer_L_16Weights",
"VisionTransformer_L_32Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
]
class MLPBlock(nn.Sequential):
"""Transformer MLP block."""
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__()
self.linear_1 = nn.Linear(in_dim, mlp_dim)
self.act = nn.GELU()
self.dropout_1 = nn.Dropout(dropout)
self.linear_2 = nn.Linear(mlp_dim, in_dim)
self.dropout_2 = nn.Dropout(dropout)
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.linear_1.weight)
nn.init.xavier_uniform_(self.linear_2.weight)
nn.init.normal_(self.linear_1.bias, std=1e-6)
nn.init.normal_(self.linear_2.bias, std=1e-6)
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
# Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiheadAttention() by default
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
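        # self.pos_embedding has shape (1, seq_length, hidden_dim), so this addition
        # broadcasts the same positional encoding across every sample in the batch.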
return self.ln(self.layers(self.dropout(input)))
class VisionTransformer(nn.Module):
"""Vision Transformer as per https://arxiv.org/abs/2010.11929."""
def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
num_classes: int = 1000,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer
input_channels = 3
# The conv_proj is a more efficient version of reshaping, permuting
# and projecting the input
self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
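        # Concretely: with kernel_size == stride == patch_size, this convolution is
        # mathematically equivalent to splitting the image into non-overlapping
        # patch_size x patch_size patches, flattening each to a vector of length
        # 3 * patch_size**2 and projecting it with a single shared linear layer.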
seq_length = (image_size // patch_size) ** 2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.seq_length = seq_length
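        # If representation_size is given, an extra Linear + Tanh "pre-logits" layer sits
        # between the class-token embedding and the classification head (mirroring the
        # optional representation layer of the reference implementations linked above);
        # otherwise the class-token embedding feeds the head directly.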
heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
if representation_size is None:
heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
else:
heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
heads_layers["act"] = nn.Tanh()
heads_layers["head"] = nn.Linear(representation_size, num_classes)
self.heads = nn.Sequential(heads_layers)
self._init_weights()
def _init_weights(self):
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.conv_proj.bias)
if hasattr(self.heads, "pre_logits"):
fan_in = self.heads.pre_logits.in_features
nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
nn.init.zeros_(self.heads.pre_logits.bias)
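        # The classification head starts from all zeros, so an untrained model initially
        # produces zero logits (a uniform prediction) regardless of the input.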
nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias)
def forward(self, x: torch.Tensor):
n, c, h, w = x.shape
p = self.patch_size
torch._assert(h == self.image_size, "Wrong image height!")
torch._assert(w == self.image_size, "Wrong image width!")
n_h = h // p
n_w = w // p
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
# Expand the class token to the full batch.
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x)
# Classifier "token" as used by standard language architectures
x = x[:, 0]
x = self.heads(x)
return x
class VisionTransformer_B_16Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_b_16
pass
class VisionTransformer_B_32Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_b_32
pass
class VisionTransformer_L_16Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_l_16
pass
class VisionTransformer_L_32Weights(Weights):
# If a default model is added here the corresponding changes need to be done in vit_l_32
pass
def _vision_transformer(
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
weights: Optional[Weights],
progress: bool,
**kwargs: Any,
) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224)
model = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
**kwargs,
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
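# Note: image_size defaults to 224 and can be overridden through kwargs when building a
# model from scratch (e.g. image_size=384); this changes seq_length and hence the size of
# the positional embedding, so weights trained at a different resolution would not load.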
def vit_b_16(
weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_b_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_16Weights.verify(weights)
return _vision_transformer(
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights,
progress=progress,
**kwargs,
)
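# Minimal usage sketch (assumes no pretrained weights are requested, since the weight
# enums above are still empty):
#
#     model = vit_b_16()                          # ViT-B/16 with random initialisation
#     logits = model(torch.rand(1, 3, 224, 224))  # -> tensor of shape (1, 1000)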
def vit_b_32(
weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_b_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_B_32Weights.verify(weights)
return _vision_transformer(
patch_size=32,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights,
progress=progress,
**kwargs,
)
def vit_l_16(
weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_l_16 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_16Weights.verify(weights)
return _vision_transformer(
patch_size=16,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights,
progress=progress,
**kwargs,
)
def vit_l_32(
    weights: Optional[VisionTransformer_L_32Weights] = None, progress: bool = True, **kwargs: Any
) -> VisionTransformer:
"""
Constructs a vit_l_32 architecture from
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
Args:
        weights (VisionTransformer_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet.
Default: None.
progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
"""
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", None)
weights = VisionTransformer_L_32Weights.verify(weights)
return _vision_transformer(
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights,
progress=progress,
**kwargs,
)