Unverified commit 7d868aa6, authored by Yiwen Song, committed by GitHub

[ViT] Adding conv_stem support (#5226)

* Adding conv_stem support

* fix lint

* bug fix

* address comments

* fix after merge

* adding back checking lines

* fix failing tests

* fix ignore

* add unittest & address comments

* fix memory issue

* address comments
parent 9fa8000d
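For context, this change lets VisionTransformer replace its patchify projection (a single patch_size x patch_size strided conv) with a small stack of 3x3 convolutions, following "Early Convolutions Help Transformers See Better" (https://arxiv.org/abs/2106.14881). Below is a minimal usage sketch, assuming a torchvision build that includes this commit; the stem mirrors the one used in the new unit test:

import torch
from torchvision.models.vision_transformer import ConvStemConfig, VisionTransformer

# One ConvStemConfig per conv -> norm -> activation block; the strides should
# multiply to patch_size (2 * 2 * 1 * 2 * 1 * 2 = 16) so the token grid is unchanged.
conv_stem_configs = [
    ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
    ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
    ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
]

# ViT-B/16 hyperparameters with the convolutional stem swapped in.
model = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    conv_stem_configs=conv_stem_configs,
)
logits = model(torch.rand(1, 3, 224, 224))  # shape: (1, 1000) with the default num_classes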
test/test_models.py

@@ -8,6 +8,7 @@ import traceback
 import warnings
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
+from typing import Any

 import pytest
 import torch
@@ -514,6 +515,35 @@ def test_generalizedrcnn_transform_repr():
     assert t.__repr__() == expected_string


+test_vit_conv_stem_configs = [
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
+]
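+# Note (added for clarity): the strides above multiply to 2 * 2 * 1 * 2 * 1 * 2 = 16,
+# so the stem downsamples by the same overall factor as the 16x16 patchify it replaces.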
+
+
+def vitc_b_16(**kwargs: Any):
+    return models.VisionTransformer(
+        image_size=224,
+        patch_size=16,
+        num_layers=12,
+        num_heads=12,
+        hidden_dim=768,
+        mlp_dim=3072,
+        conv_stem_configs=test_vit_conv_stem_configs,
+        **kwargs,
+    )
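+# Note (added for clarity): these are the ViT-B/16 hyperparameters with only the
+# patchify stem swapped out, i.e. presumably the conv-stem "ViTC" variant from the
+# paper, hence the vitc_ prefix.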
+
+
+@pytest.mark.parametrize("model_fn", [vitc_b_16])
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+def test_vitc_models(model_fn, dev):
+    test_classification_model(model_fn, dev)
+
+
 @pytest.mark.parametrize("model_fn", get_models_from_module(models))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_classification_model(model_fn, dev):

torchvision/models/vision_transformer.py

 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, NamedTuple, Optional

 import torch
 import torch.nn as nn

 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
 from ..utils import _log_api_usage_once

 __all__ = [
@@ -25,6 +26,14 @@ model_urls = {
 }


+class ConvStemConfig(NamedTuple):
+    out_channels: int
+    kernel_size: int
+    stride: int
+    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
+    activation_layer: Callable[..., nn.Module] = nn.ReLU
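+    # Note (added for clarity): each ConvStemConfig describes one
+    # conv -> norm -> activation block of the stem; e.g.
+    # ConvStemConfig(kernel_size=3, stride=2, out_channels=64) becomes
+    # Conv2d(3x3, stride 2) -> BatchNorm2d -> ReLU by default.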
+
+
 class MLPBlock(nn.Sequential):
     """Transformer MLP block."""
@@ -134,6 +143,7 @@ class VisionTransformer(nn.Module):
         num_classes: int = 1000,
         representation_size: Optional[int] = None,
         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -148,11 +158,31 @@
         self.representation_size = representation_size
         self.norm_layer = norm_layer

-        input_channels = 3
-
-        # The conv_proj is a more efficient version of reshaping, permuting
-        # and projecting the input
-        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
+        if conv_stem_configs is not None:
+            # As per https://arxiv.org/abs/2106.14881
+            seq_proj = nn.Sequential()
+            prev_channels = 3
+            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
+                seq_proj.add_module(
+                    f"conv_bn_relu_{i}",
+                    ConvNormActivation(
+                        in_channels=prev_channels,
+                        out_channels=conv_stem_layer_config.out_channels,
+                        kernel_size=conv_stem_layer_config.kernel_size,
+                        stride=conv_stem_layer_config.stride,
+                        norm_layer=conv_stem_layer_config.norm_layer,
+                        activation_layer=conv_stem_layer_config.activation_layer,
+                    ),
+                )
+                prev_channels = conv_stem_layer_config.out_channels
+            seq_proj.add_module(
+                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
+            )
+            self.conv_proj: nn.Module = seq_proj
+        else:
+            self.conv_proj = nn.Conv2d(
+                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
+            )
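+        # Note (added for clarity): ConvNormActivation defaults to
+        # (kernel_size - 1) // 2 padding, so each stem block changes the
+        # spatial size only through its stride.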

         seq_length = (image_size // patch_size) ** 2
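+        # Note (added for clarity): seq_length assumes the projection reduces each
+        # spatial dimension by patch_size, so a conv stem's strides must multiply to
+        # patch_size. E.g. image_size=224, patch_size=16 -> (224 // 16) ** 2 = 196
+        # patch tokens (a class token is prepended later in __init__).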
@@ -184,9 +214,17 @@
         self._init_weights()

     def _init_weights(self):
-        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
-        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
-        nn.init.zeros_(self.conv_proj.bias)
+        if isinstance(self.conv_proj, nn.Conv2d):
+            # Init the patchify stem
+            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
+            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
+            nn.init.zeros_(self.conv_proj.bias)
+        else:
+            # Init the last 1x1 conv of the conv stem
+            nn.init.normal_(
+                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
+            )
+            nn.init.zeros_(self.conv_proj.conv_last.bias)
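+            # Note (added for clarity): std = sqrt(2 / fan_out) is He-style fan-out
+            # initialization; for this 1x1 conv, fan_out = out_channels * 1 * 1.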

         if hasattr(self.heads, "pre_logits"):
             fan_in = self.heads.pre_logits.in_features