Unverified Commit 99e2982f authored by João Gustavo A. Amorim, committed by GitHub

Add/type annotations/model vision (#16151)

* add type annotations for Beit (PyTorch)

* add type annotations for ViT (PyTorch)

* add type annotations for Deit (PyTorch)

* change Optional[bool] to bool in some places in Beit

* change Optional[bool] to bool in some places in ViT
parent 2410d0f8
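
The same pattern is applied throughout the diffs below: annotations are added to existing signatures without changing behavior, and forward methods that may or may not append attention probabilities get a Union return type. A minimal sketch of that pattern (ExampleSelfAttention is a hypothetical stand-in, not one of the classes touched here):

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class ExampleSelfAttention(nn.Module):
    """Hypothetical module, used only to illustrate the annotation pattern applied below."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # The Union return type mirrors the real modules: the tuple gains a second
        # element (the attention probabilities) only when output_attentions=True.
        attention_output = hidden_states  # placeholder computation
        attention_probs = torch.ones(hidden_states.shape[0])  # placeholder, not real attention
        return (attention_output, attention_probs) if output_attentions else (attention_output,)
```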
......@@ -18,6 +18,7 @@
import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -99,7 +100,7 @@ def to_2tuple(x):
# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
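
The body of drop_path is not part of this hunk; a minimal sketch of the timm-style stochastic-depth formulation it follows (assuming the standard per-sample keep/rescale behaviour):

```python
import torch


def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    # Sketch: each sample in the batch is either kept (and rescaled by 1 / keep_prob)
    # or zeroed out entirely; at inference time the input passes through unchanged.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over all non-batch dims
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor
```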
......@@ -122,11 +123,11 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
......@@ -141,7 +142,7 @@ class BeitEmbeddings(nn.Module):
"""
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
......@@ -162,7 +163,7 @@ class BeitEmbeddings(nn.Module):
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None):
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()
......@@ -189,7 +190,9 @@ class PatchEmbeddings(nn.Module):
Image to Patch Embedding.
"""
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
def __init__(
self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
) -> None:
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
......@@ -202,7 +205,7 @@ class PatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
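
With the defaults above, the Conv2d with kernel_size == stride == patch_size cuts the image into non-overlapping patches: a 224x224 input yields (224 // 16) ** 2 = 196 patch tokens of dimension 768. A shape-only sanity check, assuming those defaults:

```python
import torch
from torch import nn

# Patch projection with the default hyperparameters shown above.
projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
pixel_values = torch.randn(1, 3, 224, 224)
patches = projection(pixel_values)           # (1, 768, 14, 14)
tokens = patches.flatten(2).transpose(1, 2)  # (1, 196, 768)
print(tokens.shape)                          # torch.Size([1, 196, 768])
```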
def forward(self, pixel_values):
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
# FIXME look at relaxing size constraints
if height != self.image_size[0] or width != self.image_size[1]:
......@@ -215,7 +218,7 @@ class PatchEmbeddings(nn.Module):
class BeitSelfAttention(nn.Module):
def __init__(self, config, window_size=None):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
......@@ -243,7 +246,13 @@ class BeitSelfAttention(nn.Module):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
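
transpose_for_scores reshapes (batch, seq_len, hidden) into per-head slices (batch, num_heads, seq_len, head_size); a quick illustration with assumed sizes (12 heads of 64 dims, 197 tokens):

```python
import torch

batch, seq_len, num_heads, head_size = 2, 197, 12, 64
hidden_states = torch.randn(batch, seq_len, num_heads * head_size)

x = hidden_states.view(batch, seq_len, num_heads, head_size)
x = x.permute(0, 2, 1, 3)
print(x.shape)  # torch.Size([2, 12, 197, 64])
```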
def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
......@@ -291,12 +300,12 @@ class BeitSelfOutput(nn.Module):
layernorm applied before each block.
"""
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor, gamma=None):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -304,7 +313,7 @@ class BeitSelfOutput(nn.Module):
class BeitAttention(nn.Module):
def __init__(self, config, window_size=None):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.attention = BeitSelfAttention(config, window_size=window_size)
self.output = BeitSelfOutput(config)
......@@ -328,7 +337,13 @@ class BeitAttention(nn.Module):
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
attention_output = self.output(self_outputs[0], hidden_states)
......@@ -338,7 +353,7 @@ class BeitAttention(nn.Module):
class BeitIntermediate(nn.Module):
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
......@@ -346,7 +361,7 @@ class BeitIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
......@@ -354,12 +369,12 @@ class BeitIntermediate(nn.Module):
class BeitOutput(nn.Module):
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -369,7 +384,7 @@ class BeitOutput(nn.Module):
class BeitLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config, window_size=None, drop_path_rate=0.0):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
......@@ -387,7 +402,13 @@ class BeitLayer(nn.Module):
else:
self.lambda_1, self.lambda_2 = None, None
def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
head_mask,
......@@ -422,7 +443,7 @@ class BeitLayer(nn.Module):
class BeitRelativePositionBias(nn.Module):
def __init__(self, config, window_size):
def __init__(self, config: BeitConfig, window_size: tuple) -> None:
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
......@@ -451,7 +472,7 @@ class BeitRelativePositionBias(nn.Module):
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
def forward(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH
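
For the usual 14 x 14 BEiT window (224 / 16 patches per side, an assumption, not read from this hunk) the bias table has (2 * 14 - 1) * (2 * 14 - 1) + 3 = 732 rows per head, and forward expands it to a (197, 197, num_heads) bias covering the 196 patch tokens plus [CLS]:

```python
window = (14, 14)  # assumed: 224 // 16 patches per side
num_relative_distance = (2 * window[0] - 1) * (2 * window[1] - 1) + 3
seq_len = window[0] * window[1] + 1  # patch tokens + [CLS]
print(num_relative_distance, seq_len)  # 732 197
```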
......@@ -460,7 +481,7 @@ class BeitRelativePositionBias(nn.Module):
class BeitEncoder(nn.Module):
def __init__(self, config, window_size=None):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.config = config
if config.use_shared_relative_position_bias:
......@@ -484,12 +505,12 @@ class BeitEncoder(nn.Module):
def forward(
self,
hidden_states,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
......@@ -606,7 +627,7 @@ BEIT_INPUTS_DOCSTRING = r"""
BEIT_START_DOCSTRING,
)
class BeitModel(BeitPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):
def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
super().__init__(config)
self.config = config
......@@ -643,13 +664,13 @@ class BeitModel(BeitPreTrainedModel):
)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BeitModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -691,13 +712,13 @@ class BeitModel(BeitPreTrainedModel):
class BeitPooler(nn.Module):
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__()
self.layernorm = (
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.layernorm is not None:
# Mean pool the final hidden states of the patch tokens
patch_tokens = hidden_states[:, 1:, :]
......@@ -714,7 +735,7 @@ class BeitPooler(nn.Module):
BEIT_START_DOCSTRING,
)
class BeitForMaskedImageModeling(BeitPreTrainedModel):
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
......@@ -731,14 +752,14 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
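
A hedged usage sketch for bool_masked_pos (checkpoint name and mask ratio are illustrative, not taken from this diff):

```python
import torch
from transformers import BeitForMaskedImageModeling

# Illustrative checkpoint name, used only for this sketch.
model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

pixel_values = torch.randn(1, 3, 224, 224)
num_patches = (224 // 16) ** 2
# Mask roughly 40% of the patches; shape (batch_size, num_patches), dtype bool.
bool_masked_pos = torch.rand(1, num_patches) < 0.4

outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
print(outputs.logits.shape)  # roughly (batch_size, num_patches, vocab_size)
```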
......@@ -814,7 +835,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
BEIT_START_DOCSTRING,
)
class BeitForImageClassification(BeitPreTrainedModel):
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
......@@ -836,13 +857,13 @@ class BeitForImageClassification(BeitPreTrainedModel):
)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -904,7 +925,15 @@ class BeitConvModule(nn.Module):
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int], str] = 0,
bias: bool = False,
dilation: Union[int, Tuple[int, int]] = 1,
) -> None:
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
......@@ -917,7 +946,7 @@ class BeitConvModule(nn.Module):
self.bn = nn.BatchNorm2d(out_channels)
self.activation = nn.ReLU()
def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = self.conv(input)
output = self.bn(output)
output = self.activation(output)
......@@ -939,7 +968,7 @@ class BeitPyramidPoolingModule(nn.ModuleList):
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, pool_scales, in_channels, channels, align_corners):
def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
super().__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
......@@ -953,7 +982,7 @@ class BeitPyramidPoolingModule(nn.ModuleList):
)
)
def forward(self, x):
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
......@@ -972,7 +1001,7 @@ class BeitUperHead(nn.Module):
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__()
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
......@@ -1019,7 +1048,7 @@ class BeitUperHead(nn.Module):
return output
def forward(self, encoder_hidden_states):
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
# build laterals
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
......@@ -1064,7 +1093,9 @@ class BeitFCNHead(nn.Module):
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
def __init__(
self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1
) -> None:
super().__init__()
self.in_channels = config.hidden_size
self.channels = config.auxiliary_channels
......@@ -1096,7 +1127,7 @@ class BeitFCNHead(nn.Module):
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
def forward(self, encoder_hidden_states):
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
# just take the relevant feature maps
hidden_states = encoder_hidden_states[self.in_index]
output = self.convs(hidden_states)
......@@ -1113,7 +1144,7 @@ class BeitFCNHead(nn.Module):
BEIT_START_DOCSTRING,
)
class BeitForSemanticSegmentation(BeitPreTrainedModel):
def __init__(self, config):
def __init__(self, config: BeitConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1160,13 +1191,13 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SemanticSegmentationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
......
......@@ -18,7 +18,7 @@
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -77,7 +77,7 @@ class DeiTEmbeddings(nn.Module):
"""
def __init__(self, config, use_mask_token=False):
def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
......@@ -93,7 +93,7 @@ class DeiTEmbeddings(nn.Module):
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None):
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_len, _ = embeddings.size()
......@@ -117,7 +117,13 @@ class PatchEmbeddings(nn.Module):
"""
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
def __init__(
self,
image_size: int = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
num_channels: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
......@@ -128,7 +134,7 @@ class PatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
# FIXME look at relaxing size constraints
if height != self.image_size[0] or width != self.image_size[1]:
......@@ -141,7 +147,7 @@ class PatchEmbeddings(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
class DeiTSelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
......@@ -159,12 +165,14 @@ class DeiTSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
......@@ -205,12 +213,12 @@ class DeiTSelfOutput(nn.Module):
layernorm applied before each block.
"""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -220,13 +228,13 @@ class DeiTSelfOutput(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
class DeiTAttention(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.attention = DeiTSelfAttention(config)
self.output = DeiTSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
......@@ -244,7 +252,12 @@ class DeiTAttention(nn.Module):
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
......@@ -255,7 +268,7 @@ class DeiTAttention(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
class DeiTIntermediate(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
......@@ -263,7 +276,7 @@ class DeiTIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
......@@ -273,12 +286,12 @@ class DeiTIntermediate(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
class DeiTOutput(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -291,7 +304,7 @@ class DeiTOutput(nn.Module):
class DeiTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
......@@ -301,7 +314,12 @@ class DeiTLayer(nn.Module):
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in DeiT, layernorm is applied before self-attention
head_mask,
......@@ -327,7 +345,7 @@ class DeiTLayer(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
class DeiTEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
......@@ -335,12 +353,12 @@ class DeiTEncoder(nn.Module):
def forward(
self,
hidden_states,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
......@@ -395,7 +413,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
......@@ -407,7 +425,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None:
if isinstance(module, DeiTEncoder):
module.gradient_checkpointing = value
......@@ -451,7 +469,7 @@ DEIT_INPUTS_DOCSTRING = r"""
DEIT_START_DOCSTRING,
)
class DeiTModel(DeiTPreTrainedModel):
def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
super().__init__(config)
self.config = config
......@@ -464,7 +482,7 @@ class DeiTModel(DeiTPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
def get_input_embeddings(self) -> PatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
......@@ -486,12 +504,12 @@ class DeiTModel(DeiTPreTrainedModel):
)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -554,7 +572,7 @@ class DeiTPooler(nn.Module):
DEIT_START_DOCSTRING,
)
class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
def __init__(self, config):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
......@@ -571,13 +589,13 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
......@@ -662,7 +680,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
DEIT_START_DOCSTRING,
)
class DeiTForImageClassification(DeiTPreTrainedModel):
def __init__(self, config):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
......@@ -678,13 +696,13 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -811,7 +829,7 @@ class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
DEIT_START_DOCSTRING,
)
class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
def __init__(self, config):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
......@@ -838,12 +856,12 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
)
def forward(
self,
pixel_values=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deit(
......
......@@ -389,12 +389,12 @@ class ViltSelfOutput(nn.Module):
layernorm applied before each block.
"""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -438,7 +438,7 @@ class ViltAttention(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
class ViltIntermediate(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
......@@ -446,7 +446,7 @@ class ViltIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
......@@ -456,12 +456,12 @@ class ViltIntermediate(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
class ViltOutput(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......
......@@ -17,6 +17,7 @@
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -76,7 +77,7 @@ class ViTEmbeddings(nn.Module):
"""
def __init__(self, config, use_mask_token=False):
def __init__(self, config, use_mask_token: bool = False) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
......@@ -92,7 +93,7 @@ class ViTEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def interpolate_pos_encoding(self, embeddings, height, width):
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows interpolating the pre-trained position encodings so the model can be used on higher
resolution images.
......@@ -123,7 +124,12 @@ class ViTEmbeddings(nn.Module):
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
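
interpolate_pos_encoding is what lets a checkpoint trained at 224 x 224 run on larger inputs; a usage sketch under that assumption (checkpoint name and input size are illustrative):

```python
import torch
from transformers import ViTModel

# Illustrative checkpoint name, used only for this sketch.
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# A 384x384 input instead of the 224x224 the position embeddings were trained for.
pixel_values = torch.randn(1, 3, 384, 384)
outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # (1, 1 + (384 // 16) ** 2, hidden_size) = (1, 577, 768)
```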
def forward(self, pixel_values, bool_masked_pos=None, interpolate_pos_encoding=False):
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
......@@ -157,7 +163,13 @@ class PatchEmbeddings(nn.Module):
"""
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
def __init__(
self,
image_size: int = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
num_channels: int = 3,
embed_dim: int = 768,
):
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
......@@ -168,7 +180,7 @@ class PatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values, interpolate_pos_encoding=False):
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
......@@ -180,7 +192,7 @@ class PatchEmbeddings(nn.Module):
class ViTSelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
......@@ -198,12 +210,14 @@ class ViTSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
......@@ -243,12 +257,12 @@ class ViTSelfOutput(nn.Module):
layernorm applied before each block.
"""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -257,13 +271,13 @@ class ViTSelfOutput(nn.Module):
class ViTAttention(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.attention = ViTSelfAttention(config)
self.output = ViTSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
......@@ -281,7 +295,12 @@ class ViTAttention(nn.Module):
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
......@@ -291,7 +310,7 @@ class ViTAttention(nn.Module):
class ViTIntermediate(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
......@@ -299,7 +318,7 @@ class ViTIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
......@@ -308,12 +327,12 @@ class ViTIntermediate(nn.Module):
class ViTOutput(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -325,7 +344,7 @@ class ViTOutput(nn.Module):
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
......@@ -335,7 +354,12 @@ class ViTLayer(nn.Module):
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
head_mask,
......@@ -360,7 +384,7 @@ class ViTLayer(nn.Module):
class ViTEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
......@@ -368,12 +392,12 @@ class ViTEncoder(nn.Module):
def forward(
self,
hidden_states,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
......@@ -427,7 +451,7 @@ class ViTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
......@@ -439,7 +463,7 @@ class ViTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
if isinstance(module, ViTEncoder):
module.gradient_checkpointing = value
......@@ -485,7 +509,7 @@ VIT_INPUTS_DOCSTRING = r"""
VIT_START_DOCSTRING,
)
class ViTModel(ViTPreTrainedModel):
def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
super().__init__(config)
self.config = config
......@@ -498,10 +522,10 @@ class ViTModel(ViTPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
def get_input_embeddings(self) -> PatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
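
The new Dict[int, List[int]] annotation matches that description; a small sketch of the expected argument shape (checkpoint name illustrative):

```python
from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")  # illustrative checkpoint

# Prune attention heads 0 and 2 in layer 0, and head 5 in layer 3.
heads_to_prune = {0: [0, 2], 3: [5]}
model.prune_heads(heads_to_prune)
```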
......@@ -520,13 +544,13 @@ class ViTModel(ViTPreTrainedModel):
)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
interpolate_pos_encoding=None,
return_dict=None,
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -590,7 +614,7 @@ class ViTPooler(nn.Module):
VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__(config)
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
......@@ -607,14 +631,14 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
interpolate_pos_encoding=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
......@@ -700,7 +724,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
VIT_START_DOCSTRING,
)
class ViTForImageClassification(ViTPreTrainedModel):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__(config)
self.num_labels = config.num_labels
......@@ -722,14 +746,14 @@ class ViTForImageClassification(ViTPreTrainedModel):
)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
interpolate_pos_encoding=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
......@@ -19,7 +19,7 @@ import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Set, Tuple, Union
import numpy as np
import torch
......@@ -318,7 +318,7 @@ class PatchEmbeddings(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention
class ViTMAESelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
......@@ -336,12 +336,14 @@ class ViTMAESelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
......@@ -382,12 +384,12 @@ class ViTMAESelfOutput(nn.Module):
layernorm applied before each block.
"""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -397,13 +399,13 @@ class ViTMAESelfOutput(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.attention = ViTMAESelfAttention(config)
self.output = ViTMAESelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
......@@ -421,7 +423,12 @@ class ViTMAEAttention(nn.Module):
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
......@@ -432,7 +439,7 @@ class ViTMAEAttention(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate
class ViTMAEIntermediate(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
......@@ -440,7 +447,7 @@ class ViTMAEIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
......@@ -450,12 +457,12 @@ class ViTMAEIntermediate(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTOutput
class ViTMAEOutput(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -468,7 +475,7 @@ class ViTMAEOutput(nn.Module):
class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
......@@ -478,7 +485,12 @@ class ViTMAELayer(nn.Module):
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention
head_mask,
......@@ -504,7 +516,7 @@ class ViTMAELayer(nn.Module):
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
......@@ -512,12 +524,12 @@ class ViTMAEEncoder(nn.Module):
def forward(
self,
hidden_states,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
......