Unverified Commit 09178705 authored by NielsRogge, committed by GitHub

Improve vision models (#17731)



* Improve vision models

* Add a lot of improvements

* Remove to_2tuple from swin tests

* Fix TF Swin

* Fix more tests

* Fix copies

* Improve more models

* Fix ViTMAE test

* Add channel check for TF models

* Add proper channel check for TF models

* Apply suggestion from code review

* Apply suggestions from code review

* Add channel check for Flax models, apply suggestion

* Fix bug

* Add tests for greyscale images

* Add test for interpolation of pos encodings
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 893ab124
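
The recurring change across the frameworks below is the same: each model gets its own model-specific patch-embedding class (ViTPatchEmbeddings, TFViTPatchEmbeddings, FlaxViTPatchEmbeddings, ...), the to_2tuple helper is dropped from the modeling files in favour of an inline isinstance check, and the patch-embedding layer now validates the channel dimension of pixel_values against config.num_channels before projecting. A minimal sketch of that pattern, with illustrative names and defaults rather than the exact transformers source:

import collections.abc

import torch
from torch import nn


class PatchEmbeddingsSketch(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_channels=3, hidden_size=768):
        super().__init__()
        # inline replacement for the removed to_2tuple helper
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, num_channels, height, width = pixel_values.shape
        # the new early check: fail with a clear message instead of a cryptic Conv2d shape error
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # (batch, channels, height, width) -> (batch, num_patches, hidden_size)
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


embeddings = PatchEmbeddingsSketch(num_channels=1)
print(embeddings(torch.zeros(2, 1, 224, 224)).shape)  # torch.Size([2, 196, 768])
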
...@@ -84,7 +84,7 @@ VIT_INPUTS_DOCSTRING = r""" ...@@ -84,7 +84,7 @@ VIT_INPUTS_DOCSTRING = r"""
""" """
class FlaxPatchEmbeddings(nn.Module): class FlaxViTPatchEmbeddings(nn.Module):
config: ViTConfig config: ViTConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation dtype: jnp.dtype = jnp.float32 # the dtype of the computation
...@@ -94,6 +94,7 @@ class FlaxPatchEmbeddings(nn.Module): ...@@ -94,6 +94,7 @@ class FlaxPatchEmbeddings(nn.Module):
patch_size = self.config.patch_size patch_size = self.config.patch_size
num_patches = (image_size // patch_size) * (image_size // patch_size) num_patches = (image_size // patch_size) * (image_size // patch_size)
self.num_patches = num_patches self.num_patches = num_patches
self.num_channels = self.config.num_channels
self.projection = nn.Conv( self.projection = nn.Conv(
self.config.hidden_size, self.config.hidden_size,
kernel_size=(patch_size, patch_size), kernel_size=(patch_size, patch_size),
...@@ -104,9 +105,14 @@ class FlaxPatchEmbeddings(nn.Module): ...@@ -104,9 +105,14 @@ class FlaxPatchEmbeddings(nn.Module):
) )
def __call__(self, pixel_values): def __call__(self, pixel_values):
x = self.projection(pixel_values) num_channels = pixel_values.shape[-1]
batch_size, _, _, channels = x.shape if num_channels != self.num_channels:
return jnp.reshape(x, (batch_size, -1, channels)) raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.projection(pixel_values)
batch_size, _, _, channels = embeddings.shape
return jnp.reshape(embeddings, (batch_size, -1, channels))
class FlaxViTEmbeddings(nn.Module): class FlaxViTEmbeddings(nn.Module):
...@@ -117,7 +123,7 @@ class FlaxViTEmbeddings(nn.Module): ...@@ -117,7 +123,7 @@ class FlaxViTEmbeddings(nn.Module):
def setup(self): def setup(self):
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype) self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
num_patches = self.patch_embeddings.num_patches num_patches = self.patch_embeddings.num_patches
self.position_embeddings = self.param( self.position_embeddings = self.param(
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
...@@ -420,7 +426,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel): ...@@ -420,7 +426,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
): ):
module = self.module_class(config=config, dtype=dtype, **kwargs) module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None: if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3) input_shape = (1, config.image_size, config.image_size, config.num_channels)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
......
...@@ -52,19 +52,6 @@ _IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224" ...@@ -52,19 +52,6 @@ _IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFViTEmbeddings(tf.keras.layers.Layer): class TFViTEmbeddings(tf.keras.layers.Layer):
""" """
Construct the CLS token, position and patch embeddings. Construct the CLS token, position and patch embeddings.
...@@ -74,7 +61,7 @@ class TFViTEmbeddings(tf.keras.layers.Layer): ...@@ -74,7 +61,7 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs): def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings") self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config self.config = config
...@@ -103,19 +90,21 @@ class TFViTEmbeddings(tf.keras.layers.Layer): ...@@ -103,19 +90,21 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
""" """
batch_size, seq_len, dim = shape_list(embeddings) batch_size, seq_len, dim = shape_list(embeddings)
npatch = seq_len - 1 num_patches = seq_len - 1
_, N, _ = shape_list(self.position_embeddings) _, num_positions, _ = shape_list(self.position_embeddings)
N -= 1 num_positions -= 1
if npatch == N and height == width: if num_patches == num_positions and height == width:
return self.position_embeddings return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1] class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:] patch_pos_embed = self.position_embeddings[:, 1:]
h0 = height // self.config.patch_size h0 = height // self.config.patch_size
w0 = width // self.config.patch_size w0 = width // self.config.patch_size
patch_pos_embed = tf.image.resize( patch_pos_embed = tf.image.resize(
images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)), images=tf.reshape(
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
),
size=(h0, w0), size=(h0, w0),
method="bicubic", method="bicubic",
) )
...@@ -150,27 +139,31 @@ class TFViTEmbeddings(tf.keras.layers.Layer): ...@@ -150,27 +139,31 @@ class TFViTEmbeddings(tf.keras.layers.Layer):
# Based on timm implementation, which can be found here: # Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFPatchEmbeddings(tf.keras.layers.Layer): class TFViTPatchEmbeddings(tf.keras.layers.Layer):
""" """
Image to Patch Embedding. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
""" """
def __init__(self, config: ViTConfig, **kwargs): def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
image_size = to_2tuple(config.image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(config.patch_size) num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_patches = num_patches self.num_patches = num_patches
self.num_channels = config.num_channels self.num_channels = num_channels
self.embed_dim = config.hidden_size
self.config = config self.config = config
self.projection = tf.keras.layers.Conv2D( self.projection = tf.keras.layers.Conv2D(
filters=self.embed_dim, filters=hidden_size,
kernel_size=patch_size, kernel_size=patch_size,
strides=self.patch_size, strides=patch_size,
padding="valid", padding="valid",
data_format="channels_last", data_format="channels_last",
use_bias=True, use_bias=True,
...@@ -183,8 +176,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer): ...@@ -183,8 +176,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
) -> tf.Tensor: ) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values) batch_size, num_channels, height, width = shape_list(pixel_values)
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if not interpolate_pos_encoding: if not interpolate_pos_encoding:
if getattr(height, "numpy", None) and getattr(width, "numpy", None): if tf.executing_eagerly():
if height != self.image_size[0] or width != self.image_size[1]: if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError( raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f"Input image size ({height}*{width}) doesn't match model"
...@@ -201,9 +198,9 @@ class TFPatchEmbeddings(tf.keras.layers.Layer): ...@@ -201,9 +198,9 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
# Change the 2D spatial dimensions to a single temporal dimension. # Change the 2D spatial dimensions to a single temporal dimension.
# shape = (batch_size, num_patches, out_channels=embed_dim) # shape = (batch_size, num_patches, out_channels=embed_dim)
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
return x return embeddings
class TFViTSelfAttention(tf.keras.layers.Layer): class TFViTSelfAttention(tf.keras.layers.Layer):
......
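
Note on the TF-side checks above: both the channel check and the image-size check in TFViTPatchEmbeddings are guarded with tf.executing_eagerly(), since under tf.function tracing the static dimensions of pixel_values can be None and would spuriously fail a comparison against the config. A small illustration of why the guard is needed (illustrative code, not part of transformers):

import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32)])
def project(pixel_values):
    num_channels = pixel_values.shape[1]  # static dim is None while tracing
    # without the executing_eagerly() guard, `None != 3` is True and this would raise during tracing
    if tf.executing_eagerly() and num_channels != 3:
        raise ValueError("channel mismatch")
    return pixel_values


project(tf.zeros((2, 3, 8, 8)))  # traces and runs without a spurious error
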
...@@ -59,23 +59,9 @@ VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -59,23 +59,9 @@ VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class ViTEmbeddings(nn.Module): class ViTEmbeddings(nn.Module):
""" """
Construct the CLS token, position and patch embeddings. Optionally, also the mask token. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
""" """
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
...@@ -83,12 +69,7 @@ class ViTEmbeddings(nn.Module): ...@@ -83,12 +69,7 @@ class ViTEmbeddings(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = PatchEmbeddings( self.patch_embeddings = ViTPatchEmbeddings(config)
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.hidden_size,
)
num_patches = self.patch_embeddings.num_patches num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
...@@ -103,9 +84,9 @@ class ViTEmbeddings(nn.Module): ...@@ -103,9 +84,9 @@ class ViTEmbeddings(nn.Module):
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
""" """
npatch = embeddings.shape[1] - 1 num_patches = embeddings.shape[1] - 1
N = self.position_embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1
if npatch == N and height == width: if num_patches == num_positions and height == width:
return self.position_embeddings return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0] class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:] patch_pos_embed = self.position_embeddings[:, 1:]
...@@ -115,9 +96,11 @@ class ViTEmbeddings(nn.Module): ...@@ -115,9 +96,11 @@ class ViTEmbeddings(nn.Module):
# we add a small number to avoid floating point error in the interpolation # we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8 # see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1 h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate( patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), patch_pos_embed,
scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic", mode="bicubic",
align_corners=False, align_corners=False,
) )
...@@ -134,9 +117,9 @@ class ViTEmbeddings(nn.Module): ...@@ -134,9 +117,9 @@ class ViTEmbeddings(nn.Module):
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
batch_size, seq_len, _ = embeddings.size()
if bool_masked_pos is not None: if bool_masked_pos is not None:
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) seq_length = embeddings.shape[1]
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
# replace the masked visual tokens by mask_tokens # replace the masked visual tokens by mask_tokens
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
...@@ -156,41 +139,42 @@ class ViTEmbeddings(nn.Module): ...@@ -156,41 +139,42 @@ class ViTEmbeddings(nn.Module):
return embeddings return embeddings
# Based on timm implementation, which can be found here: class ViTPatchEmbeddings(nn.Module):
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
""" """
Image to Patch Embedding. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
""" """
def __init__( def __init__(self, config):
self,
image_size: int = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
num_channels: int = 3,
embed_dim: int = 768,
):
super().__init__() super().__init__()
image_size = to_2tuple(image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(patch_size) num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if not interpolate_pos_encoding: if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]: if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError( raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})." f" ({self.image_size[0]}*{self.image_size[1]})."
) )
x = self.projection(pixel_values).flatten(2).transpose(1, 2) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x return embeddings
class ViTSelfAttention(nn.Module): class ViTSelfAttention(nn.Module):
...@@ -524,7 +508,7 @@ class ViTModel(ViTPreTrainedModel): ...@@ -524,7 +508,7 @@ class ViTModel(ViTPreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
def get_input_embeddings(self) -> PatchEmbeddings: def get_input_embeddings(self) -> ViTPatchEmbeddings:
return self.embeddings.patch_embeddings return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
...@@ -613,8 +597,8 @@ class ViTPooler(nn.Module): ...@@ -613,8 +597,8 @@ class ViTPooler(nn.Module):
@add_start_docstrings( @add_start_docstrings(
"ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM" "ViT Model with a decoder on top for masked image modeling, as proposed in"
" <https://arxiv.org/abs/2111.09886>`__.", " [SimMIM](https://arxiv.org/abs/2111.09886).",
VIT_START_DOCSTRING, VIT_START_DOCSTRING,
) )
class ViTForMaskedImageModeling(ViTPreTrainedModel): class ViTForMaskedImageModeling(ViTPreTrainedModel):
...@@ -624,7 +608,11 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -624,7 +608,11 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True) self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
self.decoder = nn.Sequential( self.decoder = nn.Sequential(
nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1), nn.Conv2d(
in_channels=config.hidden_size,
out_channels=config.encoder_stride**2 * config.num_channels,
kernel_size=1,
),
nn.PixelShuffle(config.encoder_stride), nn.PixelShuffle(config.encoder_stride),
) )
......
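
For reference, the interpolate_pos_encoding refactor in ViTEmbeddings above (reshape the learned patch position embeddings to their square grid, bicubic-resize them to the new grid, flatten back) can be summarised as a standalone function; the names and example sizes below are illustrative, not the exact transformers code:

import math

import torch
from torch import nn


def interpolate_patch_pos_embed(patch_pos_embed, height, width, patch_size):
    # patch_pos_embed: (1, num_positions, dim), learned for a square grid of patches
    num_positions, dim = patch_pos_embed.shape[1], patch_pos_embed.shape[2]
    grid = int(math.sqrt(num_positions))
    h0, w0 = height // patch_size, width // patch_size
    # small offset to avoid a floating point error in the interpolation (facebookresearch/dino#8)
    h0, w0 = h0 + 0.1, w0 + 0.1
    # (1, N, dim) -> (1, grid, grid, dim) -> (1, dim, grid, grid)
    patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed, scale_factor=(h0 / grid, w0 / grid), mode="bicubic", align_corners=False
    )
    # back to (1, new_num_patches, dim)
    return patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)


# position embeddings trained for a 224/16 = 14x14 grid, resized for a 384x384 input
resized = interpolate_patch_pos_embed(torch.zeros(1, 196, 768), height=384, width=384, patch_size=16)
print(resized.shape)  # torch.Size([1, 576, 768])
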
...@@ -133,13 +133,6 @@ class TFViTMAEForPreTrainingOutput(ModelOutput): ...@@ -133,13 +133,6 @@ class TFViTMAEForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None
# copied from transformers.models.vit.modeling_tf_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
""" """
Create 2D sin/cos positional embeddings. Create 2D sin/cos positional embeddings.
...@@ -212,7 +205,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer): ...@@ -212,7 +205,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs): def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings") self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
self.num_patches = self.patch_embeddings.num_patches self.num_patches = self.patch_embeddings.num_patches
self.config = config self.config = config
...@@ -297,30 +290,30 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer): ...@@ -297,30 +290,30 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
return embeddings, mask, ids_restore return embeddings, mask, ids_restore
# Based on timm implementation, which can be found here: class TFViTMAEPatchEmbeddings(tf.keras.layers.Layer):
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFPatchEmbeddings(tf.keras.layers.Layer):
""" """
Image to Patch Embedding. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
""" """
def __init__(self, config: ViTMAEConfig, **kwargs): def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
image_size = to_2tuple(config.image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(config.patch_size) num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_patches = num_patches self.num_patches = num_patches
self.num_channels = config.num_channels self.num_channels = num_channels
self.embed_dim = config.hidden_size
self.config = config self.config = config
self.projection = tf.keras.layers.Conv2D( self.projection = tf.keras.layers.Conv2D(
filters=self.embed_dim, filters=hidden_size,
kernel_size=self.patch_size, kernel_size=patch_size,
strides=self.patch_size, strides=patch_size,
padding="valid", padding="valid",
data_format="channels_last", data_format="channels_last",
kernel_initializer="glorot_uniform", # following torch.nn.Linear kernel_initializer="glorot_uniform", # following torch.nn.Linear
...@@ -330,7 +323,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer): ...@@ -330,7 +323,12 @@ class TFPatchEmbeddings(tf.keras.layers.Layer):
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values) batch_size, num_channels, height, width = shape_list(pixel_values)
if getattr(height, "numpy", None) and getattr(width, "numpy", None): if tf.executing_eagerly():
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the"
" configuration."
)
if height != self.image_size[0] or width != self.image_size[1]: if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError( raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f"Input image size ({height}*{width}) doesn't match model"
......
...@@ -135,13 +135,6 @@ class ViTMAEForPreTrainingOutput(ModelOutput): ...@@ -135,13 +135,6 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
# copied from transformers.models.vit.modeling_vit.to_2tuple ViT->ViTMAE
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
""" """
Create 2D sin/cos positional embeddings. Create 2D sin/cos positional embeddings.
...@@ -213,12 +206,7 @@ class ViTMAEEmbeddings(nn.Module): ...@@ -213,12 +206,7 @@ class ViTMAEEmbeddings(nn.Module):
super().__init__() super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.patch_embeddings = PatchEmbeddings( self.patch_embeddings = ViTMAEPatchEmbeddings(config)
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.hidden_size,
)
self.num_patches = self.patch_embeddings.num_patches self.num_patches = self.patch_embeddings.num_patches
# fixed sin-cos embedding # fixed sin-cos embedding
self.position_embeddings = nn.Parameter( self.position_embeddings = nn.Parameter(
...@@ -291,27 +279,33 @@ class ViTMAEEmbeddings(nn.Module): ...@@ -291,27 +279,33 @@ class ViTMAEEmbeddings(nn.Module):
return embeddings, mask, ids_restore return embeddings, mask, ids_restore
# Based on timm implementation, which can be found here: class ViTMAEPatchEmbeddings(nn.Module):
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
""" """
Image to Patch Embedding. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
""" """
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): def __init__(self, config):
super().__init__() super().__init__()
image_size = to_2tuple(image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(patch_size) num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values): def forward(self, pixel_values):
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]: if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError( raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
......
...@@ -111,13 +111,6 @@ class YolosObjectDetectionOutput(ModelOutput): ...@@ -111,13 +111,6 @@ class YolosObjectDetectionOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
class YolosEmbeddings(nn.Module): class YolosEmbeddings(nn.Module):
""" """
Construct the CLS token, detection tokens, position and patch embeddings. Construct the CLS token, detection tokens, position and patch embeddings.
...@@ -129,12 +122,7 @@ class YolosEmbeddings(nn.Module): ...@@ -129,12 +122,7 @@ class YolosEmbeddings(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size)) self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
self.patch_embeddings = PatchEmbeddings( self.patch_embeddings = YolosPatchEmbeddings(config)
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.hidden_size,
)
num_patches = self.patch_embeddings.num_patches num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter( self.position_embeddings = nn.Parameter(
torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size) torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
...@@ -228,32 +216,35 @@ class InterpolateMidPositionEmbeddings(nn.Module): ...@@ -228,32 +216,35 @@ class InterpolateMidPositionEmbeddings(nn.Module):
return scale_pos_embed return scale_pos_embed
# Based on timm implementation, which can be found here: class YolosPatchEmbeddings(nn.Module):
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
""" """
Image to Patch Embedding. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
""" """
def __init__( def __init__(self, config):
self,
image_size: int = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
num_channels: int = 3,
embed_dim: int = 768,
):
super().__init__() super().__init__()
image_size = to_2tuple(image_size) image_size, patch_size = config.image_size, config.patch_size
patch_size = to_2tuple(patch_size) num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches self.num_patches = num_patches
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings return embeddings
...@@ -620,7 +611,7 @@ class YolosModel(YolosPreTrainedModel): ...@@ -620,7 +611,7 @@ class YolosModel(YolosPreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
def get_input_embeddings(self) -> PatchEmbeddings: def get_input_embeddings(self) -> YolosPatchEmbeddings:
return self.embeddings.patch_embeddings return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import contextlib import contextlib
import inspect import inspect
import logging import logging
...@@ -1534,3 +1535,9 @@ def check_json_file_has_correct_format(file_path): ...@@ -1534,3 +1535,9 @@ def check_json_file_has_correct_format(file_path):
left_indent = len(lines[1]) - len(lines[1].lstrip()) left_indent = len(lines[1]) - len(lines[1].lstrip())
assert left_indent == 2 assert left_indent == 2
assert lines[-1].strip() == "}" assert lines[-1].strip() == "}"
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
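
With the helper relocated to testing_utils (added above), test files import it from transformers.testing_utils instead of from the modeling modules (see the updated Swin test import further below). Its behaviour is unchanged; a quick sanity check, assuming a transformers version that includes this change:

from transformers.testing_utils import to_2tuple

assert to_2tuple(7) == (7, 7)
assert to_2tuple((7, 14)) == (7, 14)
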
...@@ -153,6 +153,16 @@ class BeitModelTester: ...@@ -153,6 +153,16 @@ class BeitModelTester:
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = BeitForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels): def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.num_labels config.num_labels = self.num_labels
model = BeitForSemanticSegmentation(config) model = BeitForSemanticSegmentation(config)
......
...@@ -105,7 +105,6 @@ class FlaxBeitModelTester(unittest.TestCase): ...@@ -105,7 +105,6 @@ class FlaxBeitModelTester(unittest.TestCase):
return config, pixel_values, labels return config, pixel_values, labels
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
model = FlaxBeitModel(config=config) model = FlaxBeitModel(config=config)
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
...@@ -121,6 +120,13 @@ class FlaxBeitModelTester(unittest.TestCase): ...@@ -121,6 +120,13 @@ class FlaxBeitModelTester(unittest.TestCase):
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = FlaxBeitForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
......
...@@ -37,10 +37,7 @@ if is_torch_available(): ...@@ -37,10 +37,7 @@ if is_torch_available():
Data2VecVisionForSemanticSegmentation, Data2VecVisionForSemanticSegmentation,
Data2VecVisionModel, Data2VecVisionModel,
) )
from transformers.models.data2vec.modeling_data2vec_vision import ( from transformers.models.data2vec.modeling_data2vec_vision import DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST
DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
to_2tuple,
)
if is_vision_available(): if is_vision_available():
...@@ -94,6 +91,10 @@ class Data2VecVisionModelTester: ...@@ -94,6 +91,10 @@ class Data2VecVisionModelTester:
self.out_indices = out_indices self.out_indices = out_indices
self.num_labels = num_labels self.num_labels = num_labels
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
...@@ -131,9 +132,7 @@ class Data2VecVisionModelTester: ...@@ -131,9 +132,7 @@ class Data2VecVisionModelTester:
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size) num_patches = (self.image_size // self.patch_size) ** 2
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels): def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
...@@ -286,109 +285,6 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -286,109 +285,6 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in Data2VecVision, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# Data2VecVision has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None): def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
......
...@@ -131,6 +131,25 @@ class DeiTModelTester: ...@@ -131,6 +131,25 @@ class DeiTModelTester:
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = DeiTForMaskedImageModeling(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
config.num_channels = 1
model = DeiTForMaskedImageModeling(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
model = DeiTForImageClassification(config) model = DeiTForImageClassification(config)
...@@ -139,6 +158,16 @@ class DeiTModelTester: ...@@ -139,6 +158,16 @@ class DeiTModelTester:
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = DeiTForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
...@@ -208,6 +237,10 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -208,6 +237,10 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
def test_for_image_classification(self): def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs) self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Swin model. """ """ Testing suite for the PyTorch Swin model. """
import collections
import inspect import inspect
import os import os
import pickle import pickle
...@@ -33,7 +34,7 @@ if is_torch_available(): ...@@ -33,7 +34,7 @@ if is_torch_available():
from torch import nn from torch import nn
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
...@@ -141,6 +142,25 @@ class SwinModelTester: ...@@ -141,6 +142,25 @@ class SwinModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = SwinForMaskedImageModeling(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
config.num_channels = 1
model = SwinForMaskedImageModeling(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
model = SwinForImageClassification(config) model = SwinForImageClassification(config)
...@@ -149,6 +169,16 @@ class SwinModelTester: ...@@ -149,6 +169,16 @@ class SwinModelTester:
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = SwinForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
...@@ -198,6 +228,14 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -198,6 +228,14 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_inputs_embeds(self): def test_inputs_embeds(self):
# Swin does not use inputs_embeds # Swin does not use inputs_embeds
pass pass
...@@ -299,7 +337,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -299,7 +337,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(hidden_states), expected_num_layers) self.assertEqual(len(hidden_states), expected_num_layers)
# Swin has a different seq_length # Swin has a different seq_length
patch_size = to_2tuple(config.patch_size) patch_size = (
config.patch_size
if isinstance(config.patch_size, collections.abc.Iterable)
else (config.patch_size, config.patch_size)
)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
...@@ -323,7 +365,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -323,7 +365,11 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
def test_hidden_states_output(self): def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
image_size = to_2tuple(self.model_tester.image_size) image_size = (
self.model_tester.image_size
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
else (self.model_tester.image_size, self.model_tester.image_size)
)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True inputs_dict["output_hidden_states"] = True
...@@ -339,8 +385,16 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -339,8 +385,16 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.patch_size = 3 config.patch_size = 3
image_size = to_2tuple(self.model_tester.image_size) image_size = (
patch_size = to_2tuple(config.patch_size) self.model_tester.image_size
if isinstance(self.model_tester.image_size, collections.abc.Iterable)
else (self.model_tester.image_size, self.model_tester.image_size)
)
patch_size = (
config.patch_size
if isinstance(config.patch_size, collections.abc.Iterable)
else (config.patch_size, config.patch_size)
)
padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
...@@ -354,10 +408,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -354,10 +408,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config.output_hidden_states = True config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -21,7 +21,7 @@ import unittest ...@@ -21,7 +21,7 @@ import unittest
import numpy as np import numpy as np
from transformers import SwinConfig from transformers import SwinConfig
from transformers.testing_utils import require_tf, require_vision, slow from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple
from transformers.utils import cached_property, is_tf_available, is_vision_available from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
...@@ -36,7 +36,6 @@ if is_tf_available(): ...@@ -36,7 +36,6 @@ if is_tf_available():
TFSwinForImageClassification, TFSwinForImageClassification,
TFSwinForMaskedImageModeling, TFSwinForMaskedImageModeling,
TFSwinModel, TFSwinModel,
to_2tuple,
) )
...@@ -141,12 +140,34 @@ class TFSwinModelTester: ...@@ -141,12 +140,34 @@ class TFSwinModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = TFSwinForMaskedImageModeling(config=config)
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
config.num_channels = 1
model = TFSwinForMaskedImageModeling(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
model = TFSwinForImageClassification(config) model = TFSwinForImageClassification(config)
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = TFSwinForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs config, pixel_values, labels = config_and_inputs
...@@ -192,6 +213,14 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -192,6 +213,14 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@unittest.skip(reason="Swin does not use inputs_embeds") @unittest.skip(reason="Swin does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
...@@ -336,10 +365,6 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -336,10 +365,6 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
config.output_hidden_states = True config.output_hidden_states = True
self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
...@@ -91,8 +91,7 @@ class FlaxViTModelTester(unittest.TestCase): ...@@ -91,8 +91,7 @@ class FlaxViTModelTester(unittest.TestCase):
return config, pixel_values return config, pixel_values
def create_and_check_model(self, config, pixel_values, labels):
def create_and_check_model(self, config, pixel_values):
model = FlaxViTModel(config=config) model = FlaxViTModel(config=config)
result = model(pixel_values) result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
...@@ -101,6 +100,19 @@ class FlaxViTModelTester(unittest.TestCase): ...@@ -101,6 +100,19 @@ class FlaxViTModelTester(unittest.TestCase):
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values):
config.num_labels = self.type_sequence_label_size
model = FlaxViTForImageClassification(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = FlaxViTForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
...@@ -123,7 +135,15 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -123,7 +135,15 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
# We neeed to override this test because ViT's forward signature is different than text models.
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
# We need to override this test because ViT's forward signature is different than text models.
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -133,6 +133,13 @@ class TFViTModelTester: ...@@ -133,6 +133,13 @@ class TFViTModelTester:
result = model(pixel_values, interpolate_pos_encoding=True, training=False) result = model(pixel_values, interpolate_pos_encoding=True, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = TFViTForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs config, pixel_values, labels = config_and_inputs
......
...@@ -120,6 +120,25 @@ class ViTModelTester: ...@@ -120,6 +120,25 @@ class ViTModelTester:
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = ViTForMaskedImageModeling(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
config.num_channels = 1
model = ViTForMaskedImageModeling(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
model = ViTForImageClassification(config) model = ViTForImageClassification(config)
...@@ -128,6 +147,16 @@ class ViTModelTester: ...@@ -128,6 +147,16 @@ class ViTModelTester:
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = ViTForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
...@@ -197,6 +226,10 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -197,6 +226,10 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
def test_for_image_classification(self): def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs) self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
...@@ -240,3 +273,30 @@ class ViTModelIntegrationTest(unittest.TestCase): ...@@ -240,3 +273,30 @@ class ViTModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device) expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
# ViT models have an `interpolate_pos_encoding` argument in their forward method,
# which allows interpolating the pre-trained position embeddings so the model
# can be used on higher-resolution images. The DINO model by Facebook AI leverages
# this to visualize self-attention on higher-resolution images.
model = ViTModel.from_pretrained("facebook/dino-vits8").to(torch_device)
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", size=480)
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 3601, 384))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
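For context on the numbers asserted above: facebook/dino-vits8 is a ViT-Small variant with 8x8 patches and hidden size 384 (stated here from general knowledge of the checkpoint, not from the diff), so a 480x480 input processed with interpolated position embeddings yields (480 // 8)**2 = 3600 patches plus the [CLS] token. A quick sanity check of that arithmetic:

size, patch, hidden = 480, 8, 384
seq_len = (size // patch) ** 2 + 1  # 3600 patches + 1 [CLS] token
print((1, seq_len, hidden))         # (1, 3601, 384), matching expected_shape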
...@@ -38,7 +38,6 @@ if is_tf_available(): ...@@ -38,7 +38,6 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
from transformers import TFViTMAEForPreTraining, TFViTMAEModel from transformers import TFViTMAEForPreTraining, TFViTMAEModel
from transformers.models.vit_mae.modeling_tf_vit_mae import to_2tuple
if is_vision_available(): if is_vision_available():
...@@ -67,6 +66,7 @@ class TFViTMAEModelTester: ...@@ -67,6 +66,7 @@ class TFViTMAEModelTester:
type_sequence_label_size=10, type_sequence_label_size=10,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
mask_ratio=0.6,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
...@@ -85,8 +85,14 @@ class TFViTMAEModelTester: ...@@ -85,8 +85,14 @@ class TFViTMAEModelTester:
self.attention_probs_dropout_prob = attention_probs_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope self.scope = scope
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
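As a quick illustration of the seq_length formula just introduced (the concrete values here are assumptions for the sketch, not necessarily the tester's defaults):

import math

image_size, patch_size, mask_ratio = 30, 2, 0.6   # illustrative values only
num_patches = (image_size // patch_size) ** 2     # 15 * 15 = 225
seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
print(seq_length)  # ceil(0.4 * 226) = 91 visible tokens after random masking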
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
...@@ -116,29 +122,21 @@ class TFViTMAEModelTester: ...@@ -116,29 +122,21 @@ class TFViTMAEModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
mask_ratio=self.mask_ratio,
) )
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
model = TFViTMAEModel(config=config) model = TFViTMAEModel(config=config)
result = model(pixel_values, training=False) result = model(pixel_values, training=False)
# expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_pretraining(self, config, pixel_values, labels): def create_and_check_for_pretraining(self, config, pixel_values, labels):
model = TFViTMAEForPreTraining(config) model = TFViTMAEForPreTraining(config)
result = model(pixel_values, training=False) result = model(pixel_values, training=False)
# expected sequence length = num_patches # expected sequence length = num_patches
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
expected_seq_len = num_patches
num_patches = (self.image_size // self.patch_size) ** 2
expected_num_channels = self.patch_size**2 * self.num_channels expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
# test greyscale images # test greyscale images
config.num_channels = 1 config.num_channels = 1
...@@ -147,7 +145,7 @@ class TFViTMAEModelTester: ...@@ -147,7 +145,7 @@ class TFViTMAEModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, training=False) result = model(pixel_values, training=False)
expected_num_channels = self.patch_size**2 expected_num_channels = self.patch_size**2
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
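The assertions above encode the shape of the ViTMAE pretraining head: one row of logits per patch, each row a flattened patch of pixels, hence patch_size**2 * num_channels values (or patch_size**2 in the greyscale case). A small illustrative check (values are assumptions, not the tester's defaults):

batch_size, image_size, patch_size, num_channels = 2, 30, 2, 3  # illustrative only
num_patches = (image_size // patch_size) ** 2            # 225 patches
expected_num_channels = patch_size**2 * num_channels     # 2 * 2 * 3 = 12 values per patch
print((batch_size, num_patches, expected_num_channels))  # (2, 225, 12)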
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
...@@ -179,7 +177,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -179,7 +177,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
@unittest.skip(reason="ViTMAE does not use inputs_embeds") @unittest.skip(reason="ViTMAE does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
# ViTMAE does not use inputs_embeds
pass pass
def test_model_common_attributes(self): def test_model_common_attributes(self):
...@@ -266,114 +263,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -266,114 +263,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
output_for_kw_input = model(**inputs_np, noise=noise) output_for_kw_input = model(**inputs_np, noise=noise)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input) self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# ViTMAE has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test # to generate masks during test
def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
......
...@@ -35,7 +35,7 @@ if is_torch_available(): ...@@ -35,7 +35,7 @@ if is_torch_available():
from torch import nn from torch import nn
from transformers import ViTMAEForPreTraining, ViTMAEModel from transformers import ViTMAEForPreTraining, ViTMAEModel
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
...@@ -64,6 +64,7 @@ class ViTMAEModelTester: ...@@ -64,6 +64,7 @@ class ViTMAEModelTester:
type_sequence_label_size=10, type_sequence_label_size=10,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
mask_ratio=0.6,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
...@@ -82,8 +83,14 @@ class ViTMAEModelTester: ...@@ -82,8 +83,14 @@ class ViTMAEModelTester:
self.attention_probs_dropout_prob = attention_probs_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope self.scope = scope
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
...@@ -109,6 +116,7 @@ class ViTMAEModelTester: ...@@ -109,6 +116,7 @@ class ViTMAEModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
mask_ratio=self.mask_ratio,
) )
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
...@@ -116,26 +124,16 @@ class ViTMAEModelTester: ...@@ -116,26 +124,16 @@ class ViTMAEModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
# expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_pretraining(self, config, pixel_values, labels): def create_and_check_for_pretraining(self, config, pixel_values, labels):
model = ViTMAEForPreTraining(config) model = ViTMAEForPreTraining(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
# expected sequence length = num_patches
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
expected_seq_len = num_patches
num_patches = (self.image_size // self.patch_size) ** 2
expected_num_channels = self.patch_size**2 * self.num_channels expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
# test greyscale images # test greyscale images
config.num_channels = 1 config.num_channels = 1
...@@ -145,7 +143,7 @@ class ViTMAEModelTester: ...@@ -145,7 +143,7 @@ class ViTMAEModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values) result = model(pixel_values)
expected_num_channels = self.patch_size**2 expected_num_channels = self.patch_size**2
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels))
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
...@@ -175,8 +173,8 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -175,8 +173,8 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
# ViTMAE does not use inputs_embeds
pass pass
def test_model_common_attributes(self): def test_model_common_attributes(self):
...@@ -208,126 +206,6 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -208,126 +206,6 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_pretraining(*config_and_inputs) self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# ViTMAE has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise # overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test # to generate masks during test
def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict): def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
......
...@@ -31,7 +31,7 @@ if is_torch_available(): ...@@ -31,7 +31,7 @@ if is_torch_available():
from torch import nn from torch import nn
from transformers import YolosForObjectDetection, YolosModel from transformers import YolosForObjectDetection, YolosModel
from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
...@@ -86,9 +86,7 @@ class YolosModelTester: ...@@ -86,9 +86,7 @@ class YolosModelTester:
self.num_detection_tokens = num_detection_tokens self.num_detection_tokens = num_detection_tokens
# we set the expected sequence length (which is used in several tests) # we set the expected sequence length (which is used in several tests)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
self.expected_seq_len = num_patches + 1 + self.num_detection_tokens self.expected_seq_len = num_patches + 1 + self.num_detection_tokens
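The YOLOS sequence length follows the same pattern as the other testers, with the detection tokens added on top of the patches and the [CLS] token. A short illustrative computation (values are assumptions for this sketch, not the tester's defaults):

image_size, patch_size, num_detection_tokens = (30, 16), 2, 10  # illustrative only
num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)  # 8 * 15 = 120
expected_seq_len = num_patches + 1 + num_detection_tokens                    # 120 + 1 + 10 = 131
print(expected_seq_len)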
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
......