"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d688af19e5ce92c1395820a89e3f3b635eacc2ba"
Unverified Commit 09178705 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Improve vision models (#17731)



* Improve vision models

* Add a lot of improvements

* Remove to_2tuple from swin tests

* Fix TF Swin

* Fix more tests

* Fix copies

* Improve more models

* Fix ViTMAE test

* Add channel check for TF models

* Add proper channel check for TF models

* Apply suggestion from code review

* Apply suggestions from code review

* Add channel check for Flax models, apply suggestion

* Fix bug

* Add tests for greyscale images

* Add test for interpolation of pos encodings
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 893ab124
@@ -91,17 +91,7 @@ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
     """
-# Inspired by
-# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
-# From PyTorch internals
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)
-# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
-def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
@@ -112,16 +102,16 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
     argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
-class DropPath(nn.Module):
+class BeitDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
     def __init__(self, drop_prob: Optional[float] = None) -> None:
@@ -151,12 +141,7 @@ class BeitEmbeddings(nn.Module):
             self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         else:
             self.mask_token = None
-        self.patch_embeddings = PatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.hidden_size,
-        )
+        self.patch_embeddings = BeitPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         if config.use_absolute_position_embeddings:
             self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
@@ -184,38 +169,43 @@ class BeitEmbeddings(nn.Module):
         return embeddings
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class PatchEmbeddings(nn.Module):
+class BeitPatchEmbeddings(nn.Module):
     """
-    Image to Patch Embedding.
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
     """
-    def __init__(
-        self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
-    ) -> None:
+    def __init__(self, config):
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
         self.image_size = image_size
         self.patch_size = patch_size
+        self.num_channels = num_channels
         self.num_patches = num_patches
         self.patch_shape = patch_shape
-        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
-        # FIXME look at relaxing size constraints
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         if height != self.image_size[0] or width != self.image_size[1]:
             raise ValueError(
                 f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
             )
-        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
-        return x
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
 class BeitSelfAttention(nn.Module):
@@ -393,7 +383,7 @@ class BeitLayer(nn.Module):
         self.intermediate = BeitIntermediate(config)
         self.output = BeitOutput(config)
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         init_values = config.layer_scale_init_value
...
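The BEiT hunks above only rename DropPath to BeitDropPath and switch the helper's argument from x to input; the stochastic-depth math itself is unchanged. As a minimal, self-contained sketch (not part of this diff) of what drop_path does during training:

import torch

def drop_path(input, drop_prob=0.0, training=False):
    # same logic as the function shown in the hunk above
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # one Bernoulli draw per sample
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # 1 with probability keep_prob, else 0
    return input.div(keep_prob) * random_tensor  # rescale so the expected activation is preserved

x = torch.ones(10000, 8)
print(drop_path(x, drop_prob=0.2, training=True).mean())  # ~1.0: unbiased despite dropping ~20% of samples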
@@ -171,6 +171,7 @@ class FlaxBeitPatchEmbeddings(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     def setup(self):
+        self.num_channels = self.config.num_channels
         image_size = self.config.image_size
         patch_size = self.config.patch_size
         num_patches = (image_size // patch_size) * (image_size // patch_size)
@@ -187,6 +188,11 @@ class FlaxBeitPatchEmbeddings(nn.Module):
         )
     def __call__(self, pixel_values):
+        num_channels = pixel_values.shape[-1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         embeddings = self.projection(pixel_values)
         batch_size, _, _, channels = embeddings.shape
         return jnp.reshape(embeddings, (batch_size, -1, channels))
@@ -603,7 +609,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         if input_shape is None:
-            input_shape = (1, config.image_size, config.image_size, 3)
+            input_shape = (1, config.image_size, config.image_size, config.num_channels)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
...
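The Flax BEiT hunks add the same channel check for channels-last inputs and derive the dummy input shape from config.num_channels instead of a hard-coded 3. A rough sketch of the shapes involved (hypothetical values, not from this diff):

import jax.numpy as jnp

num_channels = 1  # e.g. greyscale, matching a config created with num_channels=1
pixel_values = jnp.zeros((1, 224, 224, num_channels))  # Flax vision models take NHWC input
# __call__ now compares pixel_values.shape[-1] against the value stored from the config in setup()
assert pixel_values.shape[-1] == num_channels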
@@ -53,36 +53,41 @@ CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
-# Stochastic depth implementation
-# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
     """
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
-    DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
-    Connect' is a different form of dropout in a separate paper... See discussion:
-    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
-    argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
 class ConvNextDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
         super().__init__()
         self.drop_prob = drop_prob
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
 class ConvNextLayerNorm(nn.Module):
     r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
@@ -122,8 +127,14 @@ class ConvNextEmbeddings(nn.Module):
             config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
         )
         self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+        self.num_channels = config.num_channels
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        num_channels = pixel_values.shape[1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         embeddings = self.patch_embeddings(pixel_values)
         embeddings = self.layernorm(embeddings)
         return embeddings
...
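With the guard above, a channel mismatch fails fast with a clear message instead of a cryptic shape error inside the first convolution, and non-RGB inputs work as long as the config matches. A rough usage sketch (not from this diff; it assumes the post-change ConvNextConfig/ConvNextModel behaviour):

import torch
from transformers import ConvNextConfig, ConvNextModel

config = ConvNextConfig(num_channels=1)      # e.g. greyscale images
model = ConvNextModel(config)
outputs = model(torch.rand(1, 1, 224, 224))  # channel dimension matches config.num_channels
# model(torch.rand(1, 3, 224, 224)) would now raise the ValueError added in the hunk above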
@@ -20,6 +20,8 @@ from typing import Dict, Optional, Tuple, Union
 import numpy as np
 import tensorflow as tf
+from transformers import shape_list
+
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
 from ...modeling_tf_utils import (
@@ -77,11 +79,18 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
             bias_initializer="zeros",
         )
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+        self.num_channels = config.num_channels
     def call(self, pixel_values):
         if isinstance(pixel_values, dict):
             pixel_values = pixel_values["pixel_values"]
+        num_channels = shape_list(pixel_values)[1]
+        if tf.executing_eagerly() and num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
         # So change the input format from `NCHW` to `NHWC`.
         # shape = (batch_size, in_height, in_width, in_channels=num_channels)
...
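The TF check is guarded by tf.executing_eagerly() because inside a traced tf.function the static channel dimension can be unknown, so a Python comparison against config.num_channels is only meaningful in eager mode. A small sketch of that difference (not from this diff):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None, None, None], tf.float32)])
def traced(pixel_values):
    # during tracing the static shape is unknown, so a channel comparison here would be meaningless
    print("static channel dim during tracing:", pixel_values.shape[1])  # prints None
    return pixel_values

traced(tf.zeros((2, 3, 224, 224)))  # the real value (3) only exists at run time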
@@ -78,36 +78,41 @@ class BaseModelOutputWithCLSToken(ModelOutput):
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
     """
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
-    DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
-    Connect' is a different form of dropout in a separate paper... See discussion:
-    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
-    argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
-# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath
 class CvtDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
         super().__init__()
         self.drop_prob = drop_prob
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
 class CvtEmbeddings(nn.Module):
     """
...
@@ -91,18 +91,8 @@ class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
     """
-# Inspired by
-# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
-# From PyTorch internals
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)
-# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
 # Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
@@ -113,17 +103,17 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
     argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
-# Copied from transformers.models.beit.modeling_beit.DropPath
-class DropPath(nn.Module):
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision
+class Data2VecVisionDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
     def __init__(self, drop_prob: Optional[float] = None) -> None:
@@ -137,8 +127,6 @@ class DropPath(nn.Module):
         return "p={}".format(self.drop_prob)
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 # Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
 class Data2VecVisionEmbeddings(nn.Module):
     """
@@ -154,12 +142,7 @@ class Data2VecVisionEmbeddings(nn.Module):
             self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         else:
             self.mask_token = None
-        self.patch_embeddings = PatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.hidden_size,
-        )
+        self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         if config.use_absolute_position_embeddings:
             self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
@@ -187,39 +170,44 @@ class Data2VecVisionEmbeddings(nn.Module):
         return embeddings
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-# Copied from transformers.models.beit.modeling_beit.PatchEmbeddings
-class PatchEmbeddings(nn.Module):
+# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
+class Data2VecVisionPatchEmbeddings(nn.Module):
     """
-    Image to Patch Embedding.
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
     """
-    def __init__(
-        self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
-    ) -> None:
+    def __init__(self, config):
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
         self.image_size = image_size
         self.patch_size = patch_size
+        self.num_channels = num_channels
         self.num_patches = num_patches
         self.patch_shape = patch_shape
-        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
-        # FIXME look at relaxing size constraints
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         if height != self.image_size[0] or width != self.image_size[1]:
             raise ValueError(
                 f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
             )
-        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
-        return x
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
 # Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
@@ -405,7 +393,7 @@ class Data2VecVisionLayer(nn.Module):
         self.intermediate = Data2VecVisionIntermediate(config)
         self.output = Data2VecVisionOutput(config)
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         init_values = config.layer_scale_init_value
...
@@ -100,7 +100,7 @@ class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
     attentions: Optional[Tuple[tf.Tensor]] = None
-class TFDropPath(tf.keras.layers.Layer):
+class TFData2VecVisionDropPath(tf.keras.layers.Layer):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     References:
         (1) github.com:rwightman/pytorch-image-models
@@ -120,8 +120,6 @@ class TFDropPath(tf.keras.layers.Layer):
         return x
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
     """
     Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
@@ -132,9 +130,7 @@ class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.config = config
-        self.patch_embeddings = TFPatchEmbeddings(
-            config=config, image_size=config.image_size, patch_size=config.patch_size, name="patch_embeddings"
-        )
+        self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings")
         self.num_patches = self.patch_embeddings.num_patches
         self.config = config
@@ -192,40 +188,32 @@ class TFData2VecVisionEmbeddings(tf.keras.layers.Layer):
         return embeddings
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class TFPatchEmbeddings(tf.keras.layers.Layer):
+class TFData2VecVisionPatchEmbeddings(tf.keras.layers.Layer):
     """
     Image to Patch Embedding.
     """
-    def __init__(self, config: Data2VecVisionConfig, image_size: int = 224, patch_size: int = 16, **kwargs):
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        image_size = (
-            config.image_size
-            if isinstance(config.image_size, collections.abc.Iterable)
-            else (config.image_size, config.image_size)
-        )
-        patch_size = (
-            config.patch_size
-            if isinstance(config.patch_size, collections.abc.Iterable)
-            else (config.patch_size, config.patch_size)
-        )
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
         self.image_size = image_size
         self.patch_size = patch_size
         self.num_patches = num_patches
         self.patch_shape = patch_shape
-        self.num_channels = config.num_channels
-        self.embed_dim = config.hidden_size
+        self.num_channels = num_channels
         self.projection = tf.keras.layers.Conv2D(
-            filters=self.embed_dim,
-            kernel_size=self.patch_size,
-            strides=self.patch_size,
+            filters=hidden_size,
+            kernel_size=patch_size,
+            strides=patch_size,
             padding="valid",
             data_format="channels_last",
             kernel_initializer="glorot_uniform",  # following torch.nn.Linear
@@ -235,7 +223,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
     def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
         batch_size, num_channels, height, width = shape_list(pixel_values)
-        if getattr(height, "numpy", None) and getattr(width, "numpy", None):
+        if tf.executing_eagerly():
+            if num_channels != self.num_channels:
+                raise ValueError(
+                    "Make sure that the channel dimension of the pixel values match with the one set in the"
+                    " configuration."
+                )
             if height != self.image_size[0] or width != self.image_size[1]:
                 raise ValueError(
                     f"Input image size ({height}*{width}) doesn't match model"
@@ -465,7 +458,7 @@ class TFData2VecVisionLayer(tf.keras.layers.Layer):
         # Using `layers.Activation` instead of `tf.identity` to better control `training`
         # behaviour.
         self.drop_path = (
-            TFDropPath(drop_path_rate, name="drop_path")
+            TFData2VecVisionDropPath(drop_path_rate, name="drop_path")
             if drop_path_rate > 0.0
            else tf.keras.layers.Activation("linear", name="drop_path")
         )
...
@@ -61,21 +61,9 @@ DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 class DeiTEmbeddings(nn.Module):
     """
     Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
     """
     def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
@@ -84,22 +72,17 @@ class DeiTEmbeddings(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
-        self.patch_embeddings = PatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.hidden_size,
-        )
+        self.patch_embeddings = DeiTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
     def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
-        batch_size, seq_len, _ = embeddings.size()
+        batch_size, seq_length, _ = embeddings.size()
         if bool_masked_pos is not None:
-            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
             # replace the masked visual tokens by mask_tokens
             mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
             embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
@@ -112,32 +95,34 @@ class DeiTEmbeddings(nn.Module):
         return embeddings
-class PatchEmbeddings(nn.Module):
+class DeiTPatchEmbeddings(nn.Module):
     """
-    Image to Patch Embedding.
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
     """
-    def __init__(
-        self,
-        image_size: int = 224,
-        patch_size: Union[int, Tuple[int, int]] = 16,
-        num_channels: int = 3,
-        embed_dim: int = 768,
-    ) -> None:
+    def __init__(self, config):
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.image_size = image_size
         self.patch_size = patch_size
+        self.num_channels = num_channels
         self.num_patches = num_patches
-        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
-        # FIXME look at relaxing size constraints
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         if height != self.image_size[0] or width != self.image_size[1]:
             raise ValueError(
                 f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
@@ -483,7 +468,7 @@ class DeiTModel(DeiTPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
-    def get_input_embeddings(self) -> PatchEmbeddings:
+    def get_input_embeddings(self) -> DeiTPatchEmbeddings:
         return self.embeddings.patch_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -570,8 +555,8 @@ class DeiTPooler(nn.Module):
 @add_start_docstrings(
-    "DeiT Model with a decoder on top for masked image modeling, as proposed in `SimMIM"
-    " <https://arxiv.org/abs/2111.09886>`__.",
+    "DeiT Model with a decoder on top for masked image modeling, as proposed in"
+    " [SimMIM](https://arxiv.org/abs/2111.09886).",
     DEIT_START_DOCSTRING,
 )
 class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
@@ -581,7 +566,11 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
         self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
         self.decoder = nn.Sequential(
-            nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
+            nn.Conv2d(
+                in_channels=config.hidden_size,
+                out_channels=config.encoder_stride**2 * config.num_channels,
+                kernel_size=1,
+            ),
             nn.PixelShuffle(config.encoder_stride),
         )
...
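The DeiT decoder change replaces the hard-coded 3 with config.num_channels so masked-image-modeling reconstruction also works for non-RGB inputs. The shape arithmetic behind the Conv2d + PixelShuffle pair, shown as a sketch with assumed values (hidden_size=768, encoder_stride=16, num_channels=3, not taken from this diff):

import torch
from torch import nn

hidden_size, encoder_stride, num_channels = 768, 16, 3
decoder = nn.Sequential(
    # project each patch feature to encoder_stride**2 * num_channels values
    nn.Conv2d(hidden_size, encoder_stride**2 * num_channels, kernel_size=1),
    # rearrange those values into an encoder_stride-times larger spatial grid with num_channels channels
    nn.PixelShuffle(encoder_stride),
)
features = torch.rand(1, hidden_size, 14, 14)   # e.g. 224/16 = 14 patches per side
print(decoder(features).shape)                  # torch.Size([1, 3, 224, 224])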
@@ -65,13 +65,6 @@ DPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)
 class DPTViTEmbeddings(nn.Module):
     """
     Construct the CLS token, position and patch embeddings.
@@ -82,12 +75,7 @@ class DPTViTEmbeddings(nn.Module):
         super().__init__()
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
-        self.patch_embeddings = DPTViTPatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.hidden_size,
-        )
+        self.patch_embeddings = DPTViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -138,19 +126,27 @@ class DPTViTPatchEmbeddings(nn.Module):
     """
-    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+    def __init__(self, config):
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.image_size = image_size
         self.patch_size = patch_size
+        self.num_channels = num_channels
         self.num_patches = num_patches
-        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
     def forward(self, pixel_values):
         batch_size, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
         return embeddings
...
@@ -54,21 +54,23 @@ GLPN_PRETRAINED_MODEL_ARCHIVE_LIST = [
 # Copied from transformers.models.segformer.modeling_segformer.drop_path
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
     """
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
-    DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
-    Connect' is a different form of dropout in a separate paper... See discussion:
-    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
-    argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
@@ -76,13 +78,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
 class GLPNDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
         super().__init__()
         self.drop_prob = drop_prob
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
 # Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
 class GLPNOverlapPatchEmbeddings(nn.Module):
...
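The extra_repr method added to the various DropPath classes makes the drop probability visible when a model is printed. A tiny illustration (not from this diff, using a stand-in class name) of what that buys:

from torch import nn

class DropPathDemo(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def extra_repr(self) -> str:
        # picked up by nn.Module.__repr__ when printing the module tree
        return "p={}".format(self.drop_prob)

print(DropPathDemo(0.1))  # DropPathDemo(p=0.1)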
@@ -126,8 +126,14 @@ class LevitPatchEmbeddings(nn.Module):
         self.embedding_layer_4 = LevitConvEmbeddings(
             config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
         )
+        self.num_channels = config.num_channels
     def forward(self, pixel_values):
+        num_channels = pixel_values.shape[1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         embeddings = self.embedding_layer_1(pixel_values)
         embeddings = self.activation_layer_1(embeddings)
         embeddings = self.embedding_layer_2(embeddings)
...
@@ -471,13 +471,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
     return loss / height_and_width
-# Copied from transformers.models.vit.modeling_vit.to_2tuple
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)
 # Copied from transformers.models.swin.modeling_swin.window_partition
 def window_partition(input_feature, window_size):
     """
@@ -506,15 +499,21 @@ def window_reverse(windows, window_size, height, width):
 def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
     """
     if drop_prob == 0.0 or not training:
         return input
     keep_prob = 1 - drop_prob
     shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
-    if keep_prob > 0.0 and scale_by_keep:
-        random_tensor.div_(keep_prob)
-    return input * random_tensor
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
 class MaskFormerSwinEmbeddings(nn.Module):
@@ -525,12 +524,7 @@ class MaskFormerSwinEmbeddings(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.patch_embeddings = MaskFormerSwinPatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.embed_dim,
-        )
+        self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.patch_grid = self.patch_embeddings.grid_size
@@ -559,17 +553,21 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
     Image to Patch Embedding, including padding.
     """
-    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+    def __init__(self, config):
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.embed_dim
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.image_size = image_size
         self.patch_size = patch_size
+        self.num_channels = num_channels
         self.num_patches = num_patches
         self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
-        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
     def maybe_pad(self, pixel_values, height, width):
         if width % self.patch_size[1] != 0:
@@ -581,7 +579,11 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
         return pixel_values
     def forward(self, pixel_values):
-        _, _, height, width = pixel_values.shape
+        _, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         # pad the input to be divisible by self.patch_size, if needed
         pixel_values = self.maybe_pad(pixel_values, height, width)
         embeddings = self.projection(pixel_values)
@@ -649,13 +651,15 @@ class MaskFormerSwinPatchMerging(nn.Module):
 class MaskFormerSwinDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-    def __init__(self, drop_prob=None, scale_by_keep=True):
-        super(MaskFormerSwinDropPath, self).__init__()
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
         self.drop_prob = drop_prob
-        self.scale_by_keep = scale_by_keep
-    def forward(self, input):
-        return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return drop_path(x, self.drop_prob, self.training)
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
 # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
@@ -670,7 +674,10 @@ class MaskFormerSwinSelfAttention(nn.Module):
         self.num_attention_heads = num_heads
         self.attention_head_size = int(dim / num_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.window_size = to_2tuple(config.window_size)
+        window_size = config.window_size
+        self.window_size = (
+            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+        )
         self.relative_position_bias_table = nn.Parameter(
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
...
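The MaskFormer hunk swaps the timm-style new_empty(shape).bernoulli_(keep_prob) mask for the floor-of-uniform construction used by the other models; both draw the same Bernoulli(keep_prob) mask, and since scale_by_keep defaulted to True the default scaling behaviour is unchanged. A quick numeric check of that equivalence (not from this diff):

import torch

keep_prob = 0.9
shape = (100000, 1)
old_mask = torch.empty(shape).bernoulli_(keep_prob)   # construction removed in the hunk above
new_mask = (keep_prob + torch.rand(shape)).floor_()   # construction used after this commit
print(old_mask.mean().item(), new_mask.mean().item())  # both ~0.9: identically distributed masks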
...@@ -50,40 +50,41 @@ POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -50,40 +50,41 @@ POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
# Copied from transformers.models.vit.modeling_vit.to_2tuple # Copied from transformers.models.beit.modeling_beit.drop_path
def to_2tuple(x): def drop_path(input, drop_prob: float = 0.0, training: bool = False):
if isinstance(x, collections.abc.Iterable): """
return x Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
return (x, x)
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
def drop_path(x, drop_prob: float = 0.0, training: bool = False): See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is argument.
misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion:
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
""" """
if drop_prob == 0.0 or not training: if drop_prob == 0.0 or not training:
return x return input
keep_prob = 1 - drop_prob keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor output = input.div(keep_prob) * random_tensor
return output return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer
class PoolFormerDropPath(nn.Module): class PoolFormerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None): def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__() super().__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class PoolFormerEmbeddings(nn.Module): class PoolFormerEmbeddings(nn.Module):
""" """
...@@ -92,17 +93,17 @@ class PoolFormerEmbeddings(nn.Module): ...@@ -92,17 +93,17 @@ class PoolFormerEmbeddings(nn.Module):
def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None): def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None):
super().__init__() super().__init__()
patch_size = to_2tuple(patch_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
stride = to_2tuple(stride) stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
padding = to_2tuple(padding) padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding)
self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity() self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity()
def forward(self, pixel_values): def forward(self, pixel_values):
x = self.projection(pixel_values) embeddings = self.projection(pixel_values)
x = self.norm(x) embeddings = self.norm(embeddings)
return x return embeddings
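The inlined isinstance expression replaces the removed to_2tuple helper: an int becomes a square (h, w) pair, while an iterable passes through untouched. A small illustration (the helper name below is hypothetical):

import collections.abc

def as_2tuple(x):
    # same expression the diff inlines in place of to_2tuple
    return x if isinstance(x, collections.abc.Iterable) else (x, x)

print(as_2tuple(16))        # (16, 16)
print(as_2tuple((16, 8)))   # (16, 8) -- already iterable, returned unchanged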
class PoolFormerGroupNorm(nn.GroupNorm): class PoolFormerGroupNorm(nn.GroupNorm):
......
...@@ -93,9 +93,15 @@ class RegNetEmbeddings(nn.Module): ...@@ -93,9 +93,15 @@ class RegNetEmbeddings(nn.Module):
self.embedder = RegNetConvLayer( self.embedder = RegNetConvLayer(
config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
) )
self.num_channels = config.num_channels
def forward(self, hidden_state): def forward(self, pixel_values):
hidden_state = self.embedder(hidden_state) num_channels = pixel_values.shape[1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
hidden_state = self.embedder(pixel_values)
return hidden_state return hidden_state
......
...@@ -81,9 +81,15 @@ class ResNetEmbeddings(nn.Module): ...@@ -81,9 +81,15 @@ class ResNetEmbeddings(nn.Module):
config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act
) )
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.num_channels = config.num_channels
def forward(self, input: Tensor) -> Tensor: def forward(self, pixel_values: Tensor) -> Tensor:
embedding = self.embedder(input) num_channels = pixel_values.shape[1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embedding = self.embedder(pixel_values)
embedding = self.pooler(embedding) embedding = self.pooler(embedding)
return embedding return embedding
...@@ -107,7 +113,7 @@ class ResNetShortCut(nn.Module): ...@@ -107,7 +113,7 @@ class ResNetShortCut(nn.Module):
class ResNetBasicLayer(nn.Module): class ResNetBasicLayer(nn.Module):
""" """
A classic ResNet's residual layer composed by a two `3x3` convolutions. A classic ResNet's residual layer composed by two `3x3` convolutions.
""" """
def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"): def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
...@@ -133,10 +139,10 @@ class ResNetBasicLayer(nn.Module): ...@@ -133,10 +139,10 @@ class ResNetBasicLayer(nn.Module):
class ResNetBottleNeckLayer(nn.Module): class ResNetBottleNeckLayer(nn.Module):
""" """
A classic ResNet's bottleneck layer composed by a three `3x3` convolutions. A classic ResNet's bottleneck layer composed by three `3x3` convolutions.
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
convolution faster. The last `1x1` convolution remap the reduced features to `out_channels`. convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
""" """
def __init__( def __init__(
......
...@@ -86,21 +86,23 @@ class SegFormerImageClassifierOutput(ImageClassifierOutput): ...@@ -86,21 +86,23 @@ class SegFormerImageClassifierOutput(ImageClassifierOutput):
# Copied from transformers.models.convnext.modeling_convnext.drop_path # Copied from transformers.models.convnext.modeling_convnext.drop_path
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True): def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
""" """
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
Connect' is a different form of dropout in a separate paper... See discussion: Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
""" """
if drop_prob == 0.0 or not training: if drop_prob == 0.0 or not training:
return x return input
keep_prob = 1 - drop_prob keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor output = input.div(keep_prob) * random_tensor
return output return output
...@@ -108,13 +110,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=T ...@@ -108,13 +110,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=T
class SegformerDropPath(nn.Module): class SegformerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None): def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__() super().__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class SegformerOverlapPatchEmbeddings(nn.Module): class SegformerOverlapPatchEmbeddings(nn.Module):
"""Construct the overlapping patch embeddings.""" """Construct the overlapping patch embeddings."""
......
...@@ -59,7 +59,7 @@ SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -59,7 +59,7 @@ SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Swin models at https://huggingface.co/models?filter=swin # See all Swin models at https://huggingface.co/models?filter=swin
] ]
# to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. # drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
@dataclass @dataclass
...@@ -203,13 +203,6 @@ class SwinImageClassifierOutput(ModelOutput): ...@@ -203,13 +203,6 @@ class SwinImageClassifierOutput(ModelOutput):
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
def window_partition(input_feature, window_size): def window_partition(input_feature, window_size):
""" """
Partitions the given input into windows. Partitions the given input into windows.
...@@ -232,20 +225,6 @@ def window_reverse(windows, window_size, height, width): ...@@ -232,20 +225,6 @@ def window_reverse(windows, window_size, height, width):
return windows return windows
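window_partition itself is untouched by this commit; for context, the Swin-style partition reshapes a (batch, height, width, channels) feature map into non-overlapping window_size x window_size windows. A sketch of that reshape, assuming the standard Swin layout:

import torch

def window_partition(input_feature, window_size):
    # (batch, height, width, channels) -> (num_windows * batch, window_size, window_size, channels)
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    return input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)

x = torch.rand(2, 56, 56, 96)              # e.g. a stage-1 Swin feature map
windows = window_partition(x, window_size=7)
print(windows.shape)                       # torch.Size([128, 7, 7, 96]): 64 windows per image x 2 images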
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return input * random_tensor
class SwinEmbeddings(nn.Module): class SwinEmbeddings(nn.Module):
""" """
Construct the patch and position embeddings. Optionally, also the mask token. Construct the patch and position embeddings. Optionally, also the mask token.
...@@ -254,12 +233,7 @@ class SwinEmbeddings(nn.Module): ...@@ -254,12 +233,7 @@ class SwinEmbeddings(nn.Module):
def __init__(self, config, use_mask_token=False): def __init__(self, config, use_mask_token=False):
super().__init__() super().__init__()
self.patch_embeddings = SwinPatchEmbeddings( self.patch_embeddings = SwinPatchEmbeddings(config)
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.embed_dim,
)
num_patches = self.patch_embeddings.num_patches num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size self.patch_grid = self.patch_embeddings.grid_size
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
...@@ -295,20 +269,25 @@ class SwinEmbeddings(nn.Module): ...@@ -295,20 +269,25 @@ class SwinEmbeddings(nn.Module):
class SwinPatchEmbeddings(nn.Module): class SwinPatchEmbeddings(nn.Module):
""" """
Image to Patch Embedding. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
""" """
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): def __init__(self, config):
super().__init__() super().__init__()
image_size = to_2tuple(image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(patch_size) num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width): def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0: if width % self.patch_size[1] != 0:
...@@ -320,7 +299,11 @@ class SwinPatchEmbeddings(nn.Module): ...@@ -320,7 +299,11 @@ class SwinPatchEmbeddings(nn.Module):
return pixel_values return pixel_values
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, _, height, width = pixel_values.shape _, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values) embeddings = self.projection(pixel_values)
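After the refactor the embedding geometry is derived entirely from the config. A quick check with the default SwinConfig (image_size=224, patch_size=4, embed_dim=96), assuming a transformers build that includes this change:

import torch
from transformers import SwinConfig, SwinModel

config = SwinConfig()  # defaults: image_size=224, patch_size=4, num_channels=3, embed_dim=96
model = SwinModel(config)
print(model.embeddings.patch_embeddings.grid_size)    # (56, 56)
print(model.embeddings.patch_embeddings.num_patches)  # 3136

outputs = model(torch.rand(1, 3, 224, 224))
print(outputs.last_hidden_state.shape)  # torch.Size([1, 49, 768]) after three patch-merging stages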
...@@ -385,16 +368,40 @@ class SwinPatchMerging(nn.Module): ...@@ -385,16 +368,40 @@ class SwinPatchMerging(nn.Module):
return input_feature return input_feature
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin
class SwinDropPath(nn.Module): class SwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None, scale_by_keep=True): def __init__(self, drop_prob: Optional[float] = None) -> None:
super(SwinDropPath, self).__init__() super().__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, input): def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(input, self.drop_prob, self.training, self.scale_by_keep) return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class SwinSelfAttention(nn.Module): class SwinSelfAttention(nn.Module):
...@@ -408,7 +415,10 @@ class SwinSelfAttention(nn.Module): ...@@ -408,7 +415,10 @@ class SwinSelfAttention(nn.Module):
self.num_attention_heads = num_heads self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads) self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = to_2tuple(config.window_size) window_size = config.window_size
self.window_size = (
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
)
self.relative_position_bias_table = nn.Parameter( self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
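The bias table has one row per possible relative offset between two positions inside a window: each axis offset ranges over -(W-1)..W-1, hence (2W-1) * (2W-1) rows, with one column per attention head. For the default window_size of 7 and the 3 heads of the first stage, that gives 13 * 13 = 169 rows:

import torch
from torch import nn

window_size, num_heads = (7, 7), 3
num_relative_positions = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)   # 13 * 13 = 169
relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_positions, num_heads))
print(relative_position_bias_table.shape)   # torch.Size([169, 3]): one bias per relative offset and head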
...@@ -997,8 +1007,8 @@ class SwinModel(SwinPreTrainedModel): ...@@ -997,8 +1007,8 @@ class SwinModel(SwinPreTrainedModel):
@add_start_docstrings( @add_start_docstrings(
"Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM" "Swin Model with a decoder on top for masked image modeling, as proposed in"
" <https://arxiv.org/abs/2111.09886>`__.", " [SimMIM](https://arxiv.org/abs/2111.09886).",
SWIN_START_DOCSTRING, SWIN_START_DOCSTRING,
) )
class SwinForMaskedImageModeling(SwinPreTrainedModel): class SwinForMaskedImageModeling(SwinPreTrainedModel):
...@@ -1009,7 +1019,9 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel): ...@@ -1009,7 +1019,9 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
num_features = int(config.embed_dim * 2 ** (config.num_layers - 1)) num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
self.decoder = nn.Sequential( self.decoder = nn.Sequential(
nn.Conv2d(in_channels=num_features, out_channels=config.encoder_stride**2 * 3, kernel_size=1), nn.Conv2d(
in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
),
nn.PixelShuffle(config.encoder_stride), nn.PixelShuffle(config.encoder_stride),
) )
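This decoder change is what makes masked image modeling work for non-RGB inputs: nn.PixelShuffle(r) maps (B, C * r**2, H, W) to (B, C, H * r, W * r), so the 1x1 convolution has to produce encoder_stride**2 * num_channels channels rather than the previously hard-coded * 3. A shape sketch for a hypothetical greyscale setup:

import torch
from torch import nn

num_features, encoder_stride, num_channels = 768, 32, 1   # greyscale, Swin-base-like feature width
decoder = nn.Sequential(
    nn.Conv2d(num_features, encoder_stride**2 * num_channels, kernel_size=1),
    nn.PixelShuffle(encoder_stride),
)
hidden = torch.rand(2, num_features, 7, 7)   # (batch, num_features, H/32, W/32)
print(decoder(hidden).shape)                 # torch.Size([2, 1, 224, 224]): full-resolution reconstruction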
......
...@@ -63,7 +63,7 @@ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -63,7 +63,7 @@ TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Swin models at https://huggingface.co/models?filter=swin # See all Swin models at https://huggingface.co/models?filter=swin
] ]
# to_2tuple, drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow # drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow
# implementations of PyTorch functionalities in the timm library. # implementations of PyTorch functionalities in the timm library.
...@@ -208,13 +208,6 @@ class TFSwinImageClassifierOutput(ModelOutput): ...@@ -208,13 +208,6 @@ class TFSwinImageClassifierOutput(ModelOutput):
reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
# Copied from transformers.models.vit.modeling_tf_vit.to_2tuple
def to_2tuple(x) -> Tuple[Any, Any]:
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor: def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
""" """
Partitions the given input into windows. Partitions the given input into windows.
...@@ -270,13 +263,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer): ...@@ -270,13 +263,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None: def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.patch_embeddings = TFSwinPatchEmbeddings( self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings")
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.embed_dim,
name="patch_embeddings",
)
self.num_patches = self.patch_embeddings.num_patches self.num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size self.patch_grid = self.patch_embeddings.grid_size
self.embed_dim = config.embed_dim self.embed_dim = config.embed_dim
...@@ -329,20 +316,25 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer): ...@@ -329,20 +316,25 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
Image to Patch Embedding. Image to Patch Embedding.
""" """
def __init__( def __init__(self, config, **kwargs):
self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768, **kwargs
) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
image_size = to_2tuple(image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(patch_size) num_channels, hidden_size = config.num_channels, config.embed_dim
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = tf.keras.layers.Conv2D( self.projection = tf.keras.layers.Conv2D(
filters=embed_dim, kernel_size=self.patch_size, strides=self.patch_size, padding="valid", name="projection" filters=hidden_size,
kernel_size=self.patch_size,
strides=self.patch_size,
padding="valid",
name="projection",
) )
def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor: def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor:
...@@ -355,7 +347,11 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer): ...@@ -355,7 +347,11 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
return pixel_values return pixel_values
def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]: def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]:
_, _, height, width = shape_list(pixel_values) _, num_channels, height, width = shape_list(pixel_values)
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
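The TF variant only raises while executing eagerly, because inside a traced tf.function the channel dimension can be symbolic and the Python-level comparison would not be meaningful. A minimal illustration of that guard:

import tensorflow as tf

print(tf.executing_eagerly())  # True by default in TF2: safe to run the Python-level check

@tf.function
def traced(x):
    print("eager during tracing:", tf.executing_eagerly())  # False: the check is skipped here
    return x

traced(tf.zeros((1, 3, 4, 4)))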
...@@ -460,7 +456,10 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): ...@@ -460,7 +456,10 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
self.num_attention_heads = num_heads self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads) self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = to_2tuple(config.window_size) window_size = config.window_size
self.window_size = (
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
)
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
coords_h = tf.range(self.window_size[0]) coords_h = tf.range(self.window_size[0])
...@@ -1252,7 +1251,7 @@ class TFSwinDecoder(tf.keras.layers.Layer): ...@@ -1252,7 +1251,7 @@ class TFSwinDecoder(tf.keras.layers.Layer):
def __init__(self, config: SwinConfig, **kwargs): def __init__(self, config: SwinConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.conv2d = tf.keras.layers.Conv2D( self.conv2d = tf.keras.layers.Conv2D(
filters=config.encoder_stride**2 * 3, kernel_size=1, strides=1, name="0" filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0"
) )
self._block_size = config.encoder_stride self._block_size = config.encoder_stride
self.pixel_shuffle = PixelShuffle(self._block_size, name="1") self.pixel_shuffle = PixelShuffle(self._block_size, name="1")
...@@ -1280,8 +1279,8 @@ class TFSwinDecoder(tf.keras.layers.Layer): ...@@ -1280,8 +1279,8 @@ class TFSwinDecoder(tf.keras.layers.Layer):
@add_start_docstrings( @add_start_docstrings(
"Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM" "Swin Model with a decoder on top for masked image modeling, as proposed in"
" <https://arxiv.org/abs/2111.09886>`__.", " [SimMIM](https://arxiv.org/abs/2111.09886).",
SWIN_START_DOCSTRING, SWIN_START_DOCSTRING,
) )
class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
......
...@@ -54,23 +54,24 @@ VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -54,23 +54,24 @@ VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
# Stochastic depth implementation # Copied from transformers.models.convnext.modeling_convnext.drop_path
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py def drop_path(input, drop_prob: float = 0.0, training: bool = False):
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
""" """
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
Connect' is a different form of dropout in a separate paper... See discussion: Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
""" """
if drop_prob == 0.0 or not training: if drop_prob == 0.0 or not training:
return x return input
keep_prob = 1 - drop_prob keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor output = input.div(keep_prob) * random_tensor
return output return output
...@@ -78,13 +79,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False): ...@@ -78,13 +79,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
class VanDropPath(nn.Module): class VanDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None): def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__() super().__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class VanOverlappingPatchEmbedder(nn.Module): class VanOverlappingPatchEmbedder(nn.Module):
""" """
......
...@@ -82,13 +82,6 @@ class ViltForImagesAndTextClassificationOutput(ModelOutput): ...@@ -82,13 +82,6 @@ class ViltForImagesAndTextClassificationOutput(ModelOutput):
attentions: Optional[List[Tuple[torch.FloatTensor]]] = None attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
class ViltEmbeddings(nn.Module): class ViltEmbeddings(nn.Module):
""" """
Construct the text and patch embeddings. Construct the text and patch embeddings.
...@@ -105,12 +98,7 @@ class ViltEmbeddings(nn.Module): ...@@ -105,12 +98,7 @@ class ViltEmbeddings(nn.Module):
self.text_embeddings = TextEmbeddings(config) self.text_embeddings = TextEmbeddings(config)
# patch embeddings # patch embeddings
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = PatchEmbeddings( self.patch_embeddings = ViltPatchEmbeddings(config)
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.hidden_size,
)
num_patches = self.patch_embeddings.num_patches num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
# modality type (text/patch) embeddings # modality type (text/patch) embeddings
...@@ -304,26 +292,32 @@ class TextEmbeddings(nn.Module): ...@@ -304,26 +292,32 @@ class TextEmbeddings(nn.Module):
return embeddings return embeddings
# Based on timm implementation, which can be found here: class ViltPatchEmbeddings(nn.Module):
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
""" """
Image to Patch Embedding. Image to Patch Embedding.
""" """
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): def __init__(self, config):
super().__init__() super().__init__()
image_size = to_2tuple(image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(patch_size) num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values): def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
x = self.projection(pixel_values) x = self.projection(pixel_values)
return x return x
......