Unverified commit 7d868aa6, authored by Yiwen Song, committed by GitHub

[ViT] Adding conv_stem support (#5226)

* Adding conv_stem support

* fix lint

* bug fix

* address comments

* fix after merge

* adding back checking lines

* fix failing tests

* fix ignore

* add unittest & address comments

* fix memory issue

* address comments
parent 9fa8000d
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
test/test_models.py
@@ -8,6 +8,7 @@ import traceback
 import warnings
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
+from typing import Any
 
 import pytest
 import torch
@@ -514,6 +515,35 @@ def test_generalizedrcnn_transform_repr():
     assert t.__repr__() == expected_string
 
 
+test_vit_conv_stem_configs = [
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
+    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
+]
+
+
+def vitc_b_16(**kwargs: Any):
+    return models.VisionTransformer(
+        image_size=224,
+        patch_size=16,
+        num_layers=12,
+        num_heads=12,
+        hidden_dim=768,
+        mlp_dim=3072,
+        conv_stem_configs=test_vit_conv_stem_configs,
+        **kwargs,
+    )
+
+
+@pytest.mark.parametrize("model_fn", [vitc_b_16])
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+def test_vitc_models(model_fn, dev):
+    test_classification_model(model_fn, dev)
+
+
 @pytest.mark.parametrize("model_fn", get_models_from_module(models))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_classification_model(model_fn, dev):
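To exercise just the new test locally, one option (assuming a source checkout with pytest available) is: pytest test/test_models.py -k test_vitc_models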
torchvision/models/vision_transformer.py
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, NamedTuple, Optional
 
 import torch
 import torch.nn as nn
 
 from .._internally_replaced_utils import load_state_dict_from_url
+from ..ops.misc import ConvNormActivation
 from ..utils import _log_api_usage_once
 
 __all__ = [
@@ -25,6 +26,14 @@ model_urls = {
 }
 
 
+class ConvStemConfig(NamedTuple):
+    out_channels: int
+    kernel_size: int
+    stride: int
+    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
+    activation_layer: Callable[..., nn.Module] = nn.ReLU
+
+
 class MLPBlock(nn.Sequential):
     """Transformer MLP block."""
@@ -134,6 +143,7 @@ class VisionTransformer(nn.Module):
         num_classes: int = 1000,
         representation_size: Optional[int] = None,
         norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -148,11 +158,31 @@ class VisionTransformer(nn.Module):
         self.representation_size = representation_size
         self.norm_layer = norm_layer
 
-        input_channels = 3
-
-        # As per https://arxiv.org/abs/2106.14881
-        # The conv_proj is a more efficient version of reshaping, permuting
-        # and projecting the input
-        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
+        if conv_stem_configs is not None:
+            # As per https://arxiv.org/abs/2106.14881
+            seq_proj = nn.Sequential()
+            prev_channels = 3
+            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
+                seq_proj.add_module(
+                    f"conv_bn_relu_{i}",
+                    ConvNormActivation(
+                        in_channels=prev_channels,
+                        out_channels=conv_stem_layer_config.out_channels,
+                        kernel_size=conv_stem_layer_config.kernel_size,
+                        stride=conv_stem_layer_config.stride,
+                        norm_layer=conv_stem_layer_config.norm_layer,
+                        activation_layer=conv_stem_layer_config.activation_layer,
+                    ),
+                )
+                prev_channels = conv_stem_layer_config.out_channels
+            seq_proj.add_module(
+                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
+            )
+            self.conv_proj: nn.Module = seq_proj
+        else:
+            self.conv_proj = nn.Conv2d(
+                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
+            )
 
         seq_length = (image_size // patch_size) ** 2
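Note on shapes (not part of the diff): seq_length is still computed from patch_size, so the conv stem's cumulative stride has to equal patch_size for the number of tokens coming out of the stem to match what the transformer expects. With the strides used in test_vit_conv_stem_configs above (2, 2, 1, 2, 1, 2) the product is 16, so a 224x224 input still yields a 14x14 grid of 196 tokens. A quick illustrative check:

import math

strides = [2, 2, 1, 2, 1, 2]  # strides of test_vit_conv_stem_configs above
patch_size, image_size = 16, 224
assert math.prod(strides) == patch_size        # conv stem downsamples by 16 overall
assert (image_size // patch_size) ** 2 == 196  # 14 x 14 feature map -> 196 tokens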
@@ -184,9 +214,17 @@
         self._init_weights()
 
     def _init_weights(self):
-        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
-        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
-        nn.init.zeros_(self.conv_proj.bias)
+        if isinstance(self.conv_proj, nn.Conv2d):
+            # Init the patchify stem
+            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
+            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
+            nn.init.zeros_(self.conv_proj.bias)
+        else:
+            # Init the last 1x1 conv of the conv stem
+            nn.init.normal_(
+                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
+            )
+            nn.init.zeros_(self.conv_proj.conv_last.bias)
 
         if hasattr(self.heads, "pre_logits"):
             fan_in = self.heads.pre_logits.in_features
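For reference, a minimal usage sketch of the new argument (assuming a torchvision build that includes this commit; the config values mirror the test above, and the 1000-class output is just the constructor's existing default, not something this diff asserts):

import torch
from torchvision.models.vision_transformer import ConvStemConfig, VisionTransformer

# Same conv stem as test_vit_conv_stem_configs in the test above
conv_stem_configs = [
    ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
    ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
    ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
    ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
]

model = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    conv_stem_configs=conv_stem_configs,  # None (the default) keeps the original patchify stem
)

out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default num_classes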