Unverified Commit 99e2982f authored by João Gustavo A. Amorim, committed by GitHub

Add/type annotations/model vision (#16151)

* add type annotations for Beit (PyTorch)

* add type annotations for ViT (PyTorch)

* add type annotations for DeiT (PyTorch)

* change Optional[bool] to bool in some places in Beit

* change Optional[bool] to bool in some places in ViT
parent 2410d0f8
parent 2410d0f8
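All four files below follow the same annotation pattern: arguments that default to None become Optional[...], boolean flags stay plain bool, __init__ gains an explicit -> None, and forward() gains a return annotation describing the tuple it can produce. A minimal sketch of that pattern (ToyBlock is a hypothetical module used only for illustration; it is not code from this commit):

from typing import Optional, Tuple, Union

import torch
from torch import nn


class ToyBlock(nn.Module):
    # Hypothetical module illustrating the annotation style applied in this PR.
    def __init__(self, hidden_size: int = 768) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        hidden_states = self.dense(hidden_states)
        if head_mask is not None:
            hidden_states = hidden_states * head_mask
        # The tuple grows by one element when attentions are requested,
        # which is what the Union[...] return annotation expresses.
        return (hidden_states, hidden_states) if output_attentions else (hidden_states,)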
@@ -18,6 +18,7 @@
 import collections.abc
 import math
 from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -99,7 +100,7 @@ def to_2tuple(x):
 # Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
@@ -122,11 +123,11 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
 class DropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
         super().__init__()
         self.drop_prob = drop_prob
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
     def extra_repr(self) -> str:
@@ -141,7 +142,7 @@ class BeitEmbeddings(nn.Module):
     """
-    def __init__(self, config):
+    def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
@@ -162,7 +163,7 @@ class BeitEmbeddings(nn.Module):
             self.position_embeddings = None
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, pixel_values, bool_masked_pos=None):
+    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
         batch_size, seq_len, _ = embeddings.size()
@@ -189,7 +190,9 @@ class PatchEmbeddings(nn.Module):
     Image to Patch Embedding.
     """
-    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+    def __init__(
+        self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
+    ) -> None:
         super().__init__()
         image_size = to_2tuple(image_size)
         patch_size = to_2tuple(patch_size)
@@ -202,7 +205,7 @@ class PatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
-    def forward(self, pixel_values):
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
         # FIXME look at relaxing size constraints
         if height != self.image_size[0] or width != self.image_size[1]:
@@ -215,7 +218,7 @@ class PatchEmbeddings(nn.Module):
 class BeitSelfAttention(nn.Module):
-    def __init__(self, config, window_size=None):
+    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -243,7 +246,13 @@ class BeitSelfAttention(nn.Module):
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)
-    def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
+    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
         mixed_query_layer = self.query(hidden_states)
         key_layer = self.transpose_for_scores(self.key(hidden_states))
@@ -291,12 +300,12 @@ class BeitSelfOutput(nn.Module):
     layernorm applied before each block.
     """
-    def __init__(self, config):
+    def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor, gamma=None):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -304,7 +313,7 @@ class BeitSelfOutput(nn.Module):
 class BeitAttention(nn.Module):
-    def __init__(self, config, window_size=None):
+    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
         super().__init__()
         self.attention = BeitSelfAttention(config, window_size=window_size)
         self.output = BeitSelfOutput(config)
@@ -328,7 +337,13 @@ class BeitAttention(nn.Module):
         self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
-    def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
+    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
         self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
         attention_output = self.output(self_outputs[0], hidden_states)
@@ -338,7 +353,7 @@ class BeitAttention(nn.Module):
 class BeitIntermediate(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -346,7 +361,7 @@ class BeitIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
@@ -354,12 +369,12 @@ class BeitIntermediate(nn.Module):
 class BeitOutput(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -369,7 +384,7 @@ class BeitOutput(nn.Module):
 class BeitLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""
-    def __init__(self, config, window_size=None, drop_path_rate=0.0):
+    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
@@ -387,7 +402,13 @@ class BeitLayer(nn.Module):
         else:
             self.lambda_1, self.lambda_2 = None, None
-    def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
+    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
         self_attention_outputs = self.attention(
             self.layernorm_before(hidden_states),  # in BEiT, layernorm is applied before self-attention
             head_mask,
@@ -422,7 +443,7 @@ class BeitLayer(nn.Module):
 class BeitRelativePositionBias(nn.Module):
-    def __init__(self, config, window_size):
+    def __init__(self, config: BeitConfig, window_size: tuple) -> None:
         super().__init__()
         self.window_size = window_size
         self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
@@ -451,7 +472,7 @@ class BeitRelativePositionBias(nn.Module):
         self.register_buffer("relative_position_index", relative_position_index)
-    def forward(self):
+    def forward(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
             self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
         )  # Wh*Ww,Wh*Ww,nH
@@ -460,7 +481,7 @@ class BeitRelativePositionBias(nn.Module):
 class BeitEncoder(nn.Module):
-    def __init__(self, config, window_size=None):
+    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
         super().__init__()
         self.config = config
         if config.use_shared_relative_position_bias:
@@ -484,12 +505,12 @@ class BeitEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -606,7 +627,7 @@ BEIT_INPUTS_DOCSTRING = r"""
     BEIT_START_DOCSTRING,
 )
 class BeitModel(BeitPreTrainedModel):
-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
         super().__init__(config)
         self.config = config
@@ -643,13 +664,13 @@ class BeitModel(BeitPreTrainedModel):
     )
     def forward(
         self,
-        pixel_values=None,
-        bool_masked_pos=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BeitModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -691,13 +712,13 @@ class BeitModel(BeitPreTrainedModel):
 class BeitPooler(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: BeitModel) -> None:
         super().__init__()
         self.layernorm = (
             nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
         )
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.layernorm is not None:
             # Mean pool the final hidden states of the patch tokens
             patch_tokens = hidden_states[:, 1:, :]
@@ -714,7 +735,7 @@ class BeitPooler(nn.Module):
     BEIT_START_DOCSTRING,
 )
 class BeitForMaskedImageModeling(BeitPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: BeitModel) -> None:
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -731,14 +752,14 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
         self,
-        pixel_values=None,
-        bool_masked_pos=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, MaskedLMOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -814,7 +835,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
     BEIT_START_DOCSTRING,
 )
 class BeitForImageClassification(BeitPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: BeitModel) -> None:
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -836,13 +857,13 @@ class BeitForImageClassification(BeitPreTrainedModel):
     )
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
@@ -904,7 +925,15 @@ class BeitConvModule(nn.Module):
     Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
     """
-    def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        padding: Union[int, Tuple[int, int], str] = 0,
+        bias: bool = False,
+        dilation: Union[int, Tuple[int, int]] = 1,
+    ) -> None:
         super().__init__()
         self.conv = nn.Conv2d(
             in_channels=in_channels,
@@ -917,7 +946,7 @@ class BeitConvModule(nn.Module):
         self.bn = nn.BatchNorm2d(out_channels)
         self.activation = nn.ReLU()
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = self.conv(input)
         output = self.bn(output)
         output = self.activation(output)
@@ -939,7 +968,7 @@ class BeitPyramidPoolingModule(nn.ModuleList):
     Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
     """
-    def __init__(self, pool_scales, in_channels, channels, align_corners):
+    def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
         super().__init__()
         self.pool_scales = pool_scales
         self.align_corners = align_corners
@@ -953,7 +982,7 @@ class BeitPyramidPoolingModule(nn.ModuleList):
                 )
             )
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         ppm_outs = []
         for ppm in self:
             ppm_out = ppm(x)
@@ -972,7 +1001,7 @@ class BeitUperHead(nn.Module):
     Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
     """
-    def __init__(self, config):
+    def __init__(self, config: BeitConfig) -> None:
         super().__init__()
         self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
@@ -1019,7 +1048,7 @@ class BeitUperHead(nn.Module):
         return output
-    def forward(self, encoder_hidden_states):
+    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         # build laterals
         laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
@@ -1064,7 +1093,9 @@ class BeitFCNHead(nn.Module):
     Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
     """
-    def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
+    def __init__(
+        self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1
+    ) -> None:
         super().__init__()
         self.in_channels = config.hidden_size
         self.channels = config.auxiliary_channels
@@ -1096,7 +1127,7 @@ class BeitFCNHead(nn.Module):
         self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
-    def forward(self, encoder_hidden_states):
+    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         # just take the relevant feature maps
         hidden_states = encoder_hidden_states[self.in_index]
         output = self.convs(hidden_states)
@@ -1113,7 +1144,7 @@ class BeitFCNHead(nn.Module):
     BEIT_START_DOCSTRING,
 )
 class BeitForSemanticSegmentation(BeitPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: BeitConfig) -> None:
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1160,13 +1191,13 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
     @replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SemanticSegmentationModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
...
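The Union[tuple, ...] return annotations above follow the library's return_dict switch: the model returns a ModelOutput dataclass when return_dict=True and a plain tuple otherwise. A minimal usage sketch of how the two branches surface to a caller (not part of the commit; assumes torch and transformers are installed, weights are randomly initialized from the config):

import torch
from transformers import BeitConfig, BeitModel

config = BeitConfig()
model = BeitModel(config)
pixel_values = torch.randn(1, 3, config.image_size, config.image_size)

# return_dict=True -> BeitModelOutputWithPooling, the dataclass branch of the annotation
outputs = model(pixel_values=pixel_values, return_dict=True)
print(outputs.last_hidden_state.shape)

# return_dict=False -> plain tuple, the other branch of the annotation
outputs_tuple = model(pixel_values=pixel_values, return_dict=False)
print(type(outputs_tuple), len(outputs_tuple))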
@@ -18,7 +18,7 @@
 import collections.abc
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Set, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -77,7 +77,7 @@ class DeiTEmbeddings(nn.Module):
     """
-    def __init__(self, config, use_mask_token=False):
+    def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
         super().__init__()
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
@@ -93,7 +93,7 @@ class DeiTEmbeddings(nn.Module):
         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, pixel_values, bool_masked_pos=None):
+    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
         batch_size, seq_len, _ = embeddings.size()
@@ -117,7 +117,13 @@ class PatchEmbeddings(nn.Module):
     """
-    def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+    def __init__(
+        self,
+        image_size: int = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        num_channels: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
         super().__init__()
         image_size = to_2tuple(image_size)
         patch_size = to_2tuple(patch_size)
@@ -128,7 +134,7 @@ class PatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
-    def forward(self, pixel_values):
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
         # FIXME look at relaxing size constraints
         if height != self.image_size[0] or width != self.image_size[1]:
@@ -141,7 +147,7 @@ class PatchEmbeddings(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
 class DeiTSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -159,12 +165,14 @@ class DeiTSelfAttention(nn.Module):
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
         return x.permute(0, 2, 1, 3)
-    def forward(self, hidden_states, head_mask=None, output_attentions=False):
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         mixed_query_layer = self.query(hidden_states)
         key_layer = self.transpose_for_scores(self.key(hidden_states))
@@ -205,12 +213,12 @@ class DeiTSelfOutput(nn.Module):
     layernorm applied before each block.
     """
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -220,13 +228,13 @@ class DeiTSelfOutput(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
 class DeiTAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.attention = DeiTSelfAttention(config)
         self.output = DeiTSelfOutput(config)
         self.pruned_heads = set()
-    def prune_heads(self, heads):
+    def prune_heads(self, heads: Set[int]) -> None:
         if len(heads) == 0:
             return
         heads, index = find_pruneable_heads_and_indices(
@@ -244,7 +252,12 @@ class DeiTAttention(nn.Module):
         self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
-    def forward(self, hidden_states, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         self_outputs = self.attention(hidden_states, head_mask, output_attentions)
         attention_output = self.output(self_outputs[0], hidden_states)
@@ -255,7 +268,7 @@ class DeiTAttention(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
 class DeiTIntermediate(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -263,7 +276,7 @@ class DeiTIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
@@ -273,12 +286,12 @@ class DeiTIntermediate(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
 class DeiTOutput(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -291,7 +304,7 @@ class DeiTOutput(nn.Module):
 class DeiTLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
@@ -301,7 +314,12 @@ class DeiTLayer(nn.Module):
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-    def forward(self, hidden_states, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         self_attention_outputs = self.attention(
             self.layernorm_before(hidden_states),  # in DeiT, layernorm is applied before self-attention
             head_mask,
@@ -327,7 +345,7 @@ class DeiTLayer(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
 class DeiTEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
@@ -335,12 +353,12 @@ class DeiTEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -395,7 +413,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
-    def _init_weights(self, module):
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
@@ -407,7 +425,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None:
         if isinstance(module, DeiTEncoder):
             module.gradient_checkpointing = value
@@ -451,7 +469,7 @@ DEIT_INPUTS_DOCSTRING = r"""
     DEIT_START_DOCSTRING,
 )
 class DeiTModel(DeiTPreTrainedModel):
-    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+    def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
         super().__init__(config)
         self.config = config
@@ -464,7 +482,7 @@ class DeiTModel(DeiTPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> PatchEmbeddings:
         return self.embeddings.patch_embeddings
     def _prune_heads(self, heads_to_prune):
@@ -486,12 +504,12 @@ class DeiTModel(DeiTPreTrainedModel):
     )
     def forward(
         self,
-        pixel_values=None,
-        bool_masked_pos=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -554,7 +572,7 @@ class DeiTPooler(nn.Module):
     DEIT_START_DOCSTRING,
 )
 class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__(config)
         self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
@@ -571,13 +589,13 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        bool_masked_pos=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, MaskedLMOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -662,7 +680,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
     DEIT_START_DOCSTRING,
 )
 class DeiTForImageClassification(DeiTPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -678,13 +696,13 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
@@ -811,7 +829,7 @@ class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
     DEIT_START_DOCSTRING,
 )
 class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: DeiTConfig) -> None:
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -838,12 +856,12 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
     )
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         outputs = self.deit(
...
@@ -389,12 +389,12 @@ class ViltSelfOutput(nn.Module):
     layernorm applied before each block.
     """
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -438,7 +438,7 @@ class ViltAttention(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
 class ViltIntermediate(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         if isinstance(config.hidden_act, str):
@@ -446,7 +446,7 @@ class ViltIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
@@ -456,12 +456,12 @@ class ViltIntermediate(nn.Module):
 # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
 class ViltOutput(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config) -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
...
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import collections.abc import collections.abc
import math import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -76,7 +77,7 @@ class ViTEmbeddings(nn.Module): ...@@ -76,7 +77,7 @@ class ViTEmbeddings(nn.Module):
""" """
def __init__(self, config, use_mask_token=False): def __init__(self, config, use_mask_token: bool = False) -> None:
super().__init__() super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
...@@ -92,7 +93,7 @@ class ViTEmbeddings(nn.Module): ...@@ -92,7 +93,7 @@ class ViTEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config self.config = config
def interpolate_pos_encoding(self, embeddings, height, width): def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
""" """
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images. resolution images.
...@@ -123,7 +124,12 @@ class ViTEmbeddings(nn.Module): ...@@ -123,7 +124,12 @@ class ViTEmbeddings(nn.Module):
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values, bool_masked_pos=None, interpolate_pos_encoding=False): def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
...@@ -157,7 +163,13 @@ class PatchEmbeddings(nn.Module): ...@@ -157,7 +163,13 @@ class PatchEmbeddings(nn.Module):
""" """
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): def __init__(
self,
image_size: int = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
num_channels: int = 3,
embed_dim: int = 768,
):
super().__init__() super().__init__()
image_size = to_2tuple(image_size) image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
...@@ -168,7 +180,7 @@ class PatchEmbeddings(nn.Module): ...@@ -168,7 +180,7 @@ class PatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values, interpolate_pos_encoding=False): def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
if not interpolate_pos_encoding: if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]: if height != self.image_size[0] or width != self.image_size[1]:
...@@ -180,7 +192,7 @@ class PatchEmbeddings(nn.Module): ...@@ -180,7 +192,7 @@ class PatchEmbeddings(nn.Module):
class ViTSelfAttention(nn.Module): class ViTSelfAttention(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError( raise ValueError(
...@@ -198,12 +210,14 @@ class ViTSelfAttention(nn.Module): ...@@ -198,12 +210,14 @@ class ViTSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x): def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape) x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3) return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, head_mask=None, output_attentions=False): def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states)) key_layer = self.transpose_for_scores(self.key(hidden_states))
...@@ -243,12 +257,12 @@ class ViTSelfOutput(nn.Module): ...@@ -243,12 +257,12 @@ class ViTSelfOutput(nn.Module):
layernorm applied before each block. layernorm applied before each block.
""" """
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
...@@ -257,13 +271,13 @@ class ViTSelfOutput(nn.Module): ...@@ -257,13 +271,13 @@ class ViTSelfOutput(nn.Module):
class ViTAttention(nn.Module): class ViTAttention(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.attention = ViTSelfAttention(config) self.attention = ViTSelfAttention(config)
self.output = ViTSelfOutput(config) self.output = ViTSelfOutput(config)
self.pruned_heads = set() self.pruned_heads = set()
def prune_heads(self, heads): def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0: if len(heads) == 0:
return return
heads, index = find_pruneable_heads_and_indices( heads, index = find_pruneable_heads_and_indices(
...@@ -281,7 +295,12 @@ class ViTAttention(nn.Module): ...@@ -281,7 +295,12 @@ class ViTAttention(nn.Module):
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions) self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
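`prune_heads` is now annotated as taking a `Set[int]` of head indices. A toy sketch of what such a helper computes, i.e. the flat weight indices that survive after dropping the given heads (the real library helper also accounts for heads pruned earlier; names below are illustrative):

```python
from typing import Set, Tuple

import torch


def pruneable_index(heads: Set[int], num_heads: int, head_size: int) -> Tuple[Set[int], torch.Tensor]:
    """Given head indices to drop, return the flat indices of the weight rows to keep."""
    mask = torch.ones(num_heads, head_size)
    for head in heads:
        mask[head] = 0  # zero out every dimension belonging to a pruned head
    index = torch.arange(num_heads * head_size)[mask.view(-1).bool()]
    return heads, index


heads, index = pruneable_index({1}, num_heads=2, head_size=4)
print(index.tolist())  # [0, 1, 2, 3] -> only head 0's dimensions remain
```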
...@@ -291,7 +310,7 @@ class ViTAttention(nn.Module): ...@@ -291,7 +310,7 @@ class ViTAttention(nn.Module):
class ViTIntermediate(nn.Module): class ViTIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
...@@ -299,7 +318,7 @@ class ViTIntermediate(nn.Module): ...@@ -299,7 +318,7 @@ class ViTIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
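The `isinstance(config.hidden_act, str)` branch exists because `hidden_act` may be either a string key or a callable, which is why only the forward pass gets a tensor annotation here. A hedged sketch of that string-or-callable dispatch; the mapping below is illustrative, not the library's actual activation table:

```python
from typing import Callable, Dict, Union

import torch
from torch import nn

# Illustrative activation table; the library keeps its own name-to-function mapping.
ACTIVATIONS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "gelu": nn.functional.gelu,
    "relu": torch.relu,
}


def resolve_activation(hidden_act: Union[str, Callable[[torch.Tensor], torch.Tensor]]):
    # String configs are looked up by name; callables are used as-is.
    return ACTIVATIONS[hidden_act] if isinstance(hidden_act, str) else hidden_act


print(resolve_activation("gelu")(torch.zeros(2)))      # tensor([0., 0.])
print(resolve_activation(torch.tanh)(torch.zeros(2)))  # tensor([0., 0.])
```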
...@@ -308,12 +327,12 @@ class ViTIntermediate(nn.Module): ...@@ -308,12 +327,12 @@ class ViTIntermediate(nn.Module):
class ViTOutput(nn.Module): class ViTOutput(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
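The second parameter `input_tensor: torch.Tensor` is there because this output module closes the MLP residual (assuming the usual residual add at the end of the method, which falls outside the visible hunk). A toy sketch with illustrative sizes:

```python
import torch
from torch import nn


class ToyOutput(nn.Module):
    def __init__(self, intermediate_size: int = 16, hidden_size: int = 8) -> None:
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(0.0)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(self.dense(hidden_states))
        # Residual connection back to the block input.
        return hidden_states + input_tensor


print(ToyOutput()(torch.randn(1, 4, 16), torch.randn(1, 4, 8)).shape)  # torch.Size([1, 4, 8])
```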
...@@ -325,7 +344,7 @@ class ViTOutput(nn.Module): ...@@ -325,7 +344,7 @@ class ViTOutput(nn.Module):
class ViTLayer(nn.Module): class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation.""" """This corresponds to the Block class in the timm implementation."""
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
...@@ -335,7 +354,12 @@ class ViTLayer(nn.Module): ...@@ -335,7 +354,12 @@ class ViTLayer(nn.Module):
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
head_mask, head_mask,
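The typed `ViTLayer.forward` makes the pre-norm block layout explicit: layernorm before self-attention and before the MLP, with a residual around each sublayer. A self-contained sketch of that layout using toy submodules (sizes and `nn.MultiheadAttention` stand in for the library's own attention class):

```python
from typing import Tuple, Union

import torch
from torch import nn


class ToyPreNormBlock(nn.Module):
    """Pre-layernorm transformer block: norm -> sublayer -> residual, twice."""

    def __init__(self, hidden_size: int = 8) -> None:
        super().__init__()
        self.layernorm_before = nn.LayerNorm(hidden_size)
        self.layernorm_after = nn.LayerNorm(hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=2, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 16), nn.GELU(), nn.Linear(16, hidden_size))

    def forward(
        self, hidden_states: torch.Tensor, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        normed = self.layernorm_before(hidden_states)                 # norm before attention
        attn_out, attn_probs = self.attention(normed, normed, normed, need_weights=True)
        hidden_states = hidden_states + attn_out                      # first residual
        mlp_out = self.mlp(self.layernorm_after(hidden_states))       # norm before MLP
        hidden_states = hidden_states + mlp_out                       # second residual
        return (hidden_states, attn_probs) if output_attentions else (hidden_states,)


print(ToyPreNormBlock()(torch.randn(1, 4, 8))[0].shape)  # torch.Size([1, 4, 8])
```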
...@@ -360,7 +384,7 @@ class ViTLayer(nn.Module): ...@@ -360,7 +384,7 @@ class ViTLayer(nn.Module):
class ViTEncoder(nn.Module): class ViTEncoder(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
...@@ -368,12 +392,12 @@ class ViTEncoder(nn.Module): ...@@ -368,12 +392,12 @@ class ViTEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
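The encoder's `Union[tuple, BaseModelOutput]` annotation comes from the `return_dict` flag: `True` returns a model-output dataclass, `False` a plain tuple of the non-`None` fields. A hedged illustration of that packing pattern with a stand-in dataclass (not the library's `BaseModelOutput`):

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch


@dataclass
class ToyEncoderOutput:
    last_hidden_state: torch.Tensor
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None


def pack_encoder_output(
    last_hidden_state: torch.Tensor,
    hidden_states: Optional[Tuple[torch.Tensor, ...]],
    attentions: Optional[Tuple[torch.Tensor, ...]],
    return_dict: bool = True,
) -> Union[tuple, ToyEncoderOutput]:
    if not return_dict:
        # Tuple form: only the fields that are not None, in declaration order.
        return tuple(v for v in (last_hidden_state, hidden_states, attentions) if v is not None)
    return ToyEncoderOutput(last_hidden_state, hidden_states, attentions)


h = torch.randn(1, 4, 8)
print(type(pack_encoder_output(h, None, None, return_dict=True)).__name__)  # ToyEncoderOutput
print(len(pack_encoder_output(h, None, None, return_dict=False)))           # 1
```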
...@@ -427,7 +451,7 @@ class ViTPreTrainedModel(PreTrainedModel): ...@@ -427,7 +451,7 @@ class ViTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
def _init_weights(self, module): def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights""" """Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)): if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
...@@ -439,7 +463,7 @@ class ViTPreTrainedModel(PreTrainedModel): ...@@ -439,7 +463,7 @@ class ViTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
if isinstance(module, ViTEncoder): if isinstance(module, ViTEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing = value
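`_init_weights` is annotated with the union of module types it actually handles, and `_set_gradient_checkpointing` only flips a flag on the encoder. A minimal sketch of both patterns on toy modules (class and argument names are illustrative):

```python
from typing import Union

from torch import nn


class ToyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gradient_checkpointing = False


def init_weights(module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm], initializer_range: float = 0.02) -> None:
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def set_gradient_checkpointing(module: nn.Module, value: bool = False) -> None:
    # Only the encoder carries the flag; other submodules are ignored.
    if isinstance(module, ToyEncoder):
        module.gradient_checkpointing = value


enc = ToyEncoder()
set_gradient_checkpointing(enc, True)
print(enc.gradient_checkpointing)  # True
init_weights(nn.Linear(4, 4))
```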
...@@ -485,7 +509,7 @@ VIT_INPUTS_DOCSTRING = r""" ...@@ -485,7 +509,7 @@ VIT_INPUTS_DOCSTRING = r"""
VIT_START_DOCSTRING, VIT_START_DOCSTRING,
) )
class ViTModel(ViTPreTrainedModel): class ViTModel(ViTPreTrainedModel):
def __init__(self, config, add_pooling_layer=True, use_mask_token=False): def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
...@@ -498,10 +522,10 @@ class ViTModel(ViTPreTrainedModel): ...@@ -498,10 +522,10 @@ class ViTModel(ViTPreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
def get_input_embeddings(self): def get_input_embeddings(self) -> PatchEmbeddings:
return self.embeddings.patch_embeddings return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
""" """
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel class PreTrainedModel
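The new `Dict[int, List[int]]` annotation documents the expected shape of `heads_to_prune`: layer index mapped to the head indices to remove in that layer. A toy sketch of the dispatch this implies (all classes below are stand-ins, not the library's):

```python
from typing import Dict, List, Set


class ToyAttention:
    def __init__(self) -> None:
        self.pruned_heads: Set[int] = set()

    def prune_heads(self, heads: Set[int]) -> None:
        self.pruned_heads |= set(heads)


class ToyLayer:
    def __init__(self) -> None:
        self.attention = ToyAttention()


class ToyModel:
    def __init__(self, num_layers: int = 2) -> None:
        self.layers = [ToyLayer() for _ in range(num_layers)]

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        # {layer index: [head indices to prune in that layer]}
        for layer, heads in heads_to_prune.items():
            self.layers[layer].attention.prune_heads(set(heads))


model = ToyModel()
model._prune_heads({0: [0, 1], 1: [2]})
print(model.layers[0].attention.pruned_heads)  # {0, 1}
```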
...@@ -520,13 +544,13 @@ class ViTModel(ViTPreTrainedModel): ...@@ -520,13 +544,13 @@ class ViTModel(ViTPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos=None, bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding=None, interpolate_pos_encoding: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
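The flags stay `Optional[bool]` rather than `bool` because `None` means "defer to the config", resolved at the top of `forward` as shown in this hunk. A minimal sketch of that resolution pattern with a stand-in config:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyConfig:
    output_attentions: bool = False
    output_hidden_states: bool = False
    use_return_dict: bool = True


def resolve_flags(
    config: ToyConfig,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> tuple:
    # None -> fall back to the config default; an explicit bool wins.
    output_attentions = output_attentions if output_attentions is not None else config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else config.use_return_dict
    return output_attentions, output_hidden_states, return_dict


print(resolve_flags(ToyConfig()))                          # (False, False, True)
print(resolve_flags(ToyConfig(), output_attentions=True))  # (True, False, True)
```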
...@@ -590,7 +614,7 @@ class ViTPooler(nn.Module): ...@@ -590,7 +614,7 @@ class ViTPooler(nn.Module):
VIT_START_DOCSTRING, VIT_START_DOCSTRING,
) )
class ViTForMaskedImageModeling(ViTPreTrainedModel): class ViTForMaskedImageModeling(ViTPreTrainedModel):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__(config) super().__init__(config)
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True) self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
...@@ -607,14 +631,14 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -607,14 +631,14 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos=None, bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding=None, interpolate_pos_encoding: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[tuple, MaskedLMOutput]:
r""" r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
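`bool_masked_pos` is now typed as `Optional[torch.BoolTensor]`: a `(batch_size, num_patches)` boolean mask where True/1 marks patches to reconstruct. A short sketch of building such a mask, assuming the usual ViT-base geometry of 224x224 images and 16x16 patches:

```python
import torch

batch_size, image_size, patch_size = 2, 224, 16
num_patches = (image_size // patch_size) ** 2  # 196 for the assumed sizes

# Randomly mask ~50% of the patches; True = masked, False = visible.
bool_masked_pos = torch.rand(batch_size, num_patches) < 0.5

print(bool_masked_pos.shape)     # torch.Size([2, 196])
print(bool_masked_pos.dtype)     # torch.bool
print(bool_masked_pos[0].sum())  # roughly 98 masked patches per image
```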
...@@ -700,7 +724,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -700,7 +724,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
VIT_START_DOCSTRING, VIT_START_DOCSTRING,
) )
class ViTForImageClassification(ViTPreTrainedModel): class ViTForImageClassification(ViTPreTrainedModel):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__(config) super().__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
...@@ -722,14 +746,14 @@ class ViTForImageClassification(ViTPreTrainedModel): ...@@ -722,14 +746,14 @@ class ViTForImageClassification(ViTPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding=None, interpolate_pos_encoding: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[tuple, SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ..., Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
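`labels` stays `Optional[torch.Tensor]`: when it is provided, a classification loss is computed from the logits over `num_labels` classes. A minimal sketch of that computation for the single-label case (tensor values below are illustrative):

```python
import torch
from torch import nn

num_labels, batch_size = 10, 2
logits = torch.randn(batch_size, num_labels)  # what a classifier head would produce
labels = torch.tensor([3, 7])                 # class indices in [0, num_labels - 1]

loss = None
if labels is not None:
    # Single-label classification: plain cross-entropy over the logits.
    loss = nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))

print(loss)
```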
......
...@@ -19,7 +19,7 @@ import collections.abc ...@@ -19,7 +19,7 @@ import collections.abc
import math import math
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -318,7 +318,7 @@ class PatchEmbeddings(nn.Module): ...@@ -318,7 +318,7 @@ class PatchEmbeddings(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention
class ViTMAESelfAttention(nn.Module): class ViTMAESelfAttention(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError( raise ValueError(
...@@ -336,12 +336,14 @@ class ViTMAESelfAttention(nn.Module): ...@@ -336,12 +336,14 @@ class ViTMAESelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x): def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape) x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3) return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, head_mask=None, output_attentions=False): def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states)) key_layer = self.transpose_for_scores(self.key(hidden_states))
...@@ -382,12 +384,12 @@ class ViTMAESelfOutput(nn.Module): ...@@ -382,12 +384,12 @@ class ViTMAESelfOutput(nn.Module):
layernorm applied before each block. layernorm applied before each block.
""" """
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
...@@ -397,13 +399,13 @@ class ViTMAESelfOutput(nn.Module): ...@@ -397,13 +399,13 @@ class ViTMAESelfOutput(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module): class ViTMAEAttention(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.attention = ViTMAESelfAttention(config) self.attention = ViTMAESelfAttention(config)
self.output = ViTMAESelfOutput(config) self.output = ViTMAESelfOutput(config)
self.pruned_heads = set() self.pruned_heads = set()
def prune_heads(self, heads): def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0: if len(heads) == 0:
return return
heads, index = find_pruneable_heads_and_indices( heads, index = find_pruneable_heads_and_indices(
...@@ -421,7 +423,12 @@ class ViTMAEAttention(nn.Module): ...@@ -421,7 +423,12 @@ class ViTMAEAttention(nn.Module):
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions) self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
...@@ -432,7 +439,7 @@ class ViTMAEAttention(nn.Module): ...@@ -432,7 +439,7 @@ class ViTMAEAttention(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate # Copied from transformers.models.vit.modeling_vit.ViTIntermediate
class ViTMAEIntermediate(nn.Module): class ViTMAEIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
...@@ -440,7 +447,7 @@ class ViTMAEIntermediate(nn.Module): ...@@ -440,7 +447,7 @@ class ViTMAEIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
...@@ -450,12 +457,12 @@ class ViTMAEIntermediate(nn.Module): ...@@ -450,12 +457,12 @@ class ViTMAEIntermediate(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTOutput # Copied from transformers.models.vit.modeling_vit.ViTOutput
class ViTMAEOutput(nn.Module): class ViTMAEOutput(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
...@@ -468,7 +475,7 @@ class ViTMAEOutput(nn.Module): ...@@ -468,7 +475,7 @@ class ViTMAEOutput(nn.Module):
class ViTMAELayer(nn.Module): class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation.""" """This corresponds to the Block class in the timm implementation."""
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
...@@ -478,7 +485,12 @@ class ViTMAELayer(nn.Module): ...@@ -478,7 +485,12 @@ class ViTMAELayer(nn.Module):
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention
head_mask, head_mask,
...@@ -504,7 +516,7 @@ class ViTMAELayer(nn.Module): ...@@ -504,7 +516,7 @@ class ViTMAELayer(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module): class ViTMAEEncoder(nn.Module):
def __init__(self, config): def __init__(self, config) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
...@@ -512,12 +524,12 @@ class ViTMAEEncoder(nn.Module): ...@@ -512,12 +524,12 @@ class ViTMAEEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......