Unverified Commit ed915cff authored by David Reguera, committed by GitHub

Add type hints for pytorch models (final batch) (#25750)



* Add type hints for table_transformer

* Add type hints to Timesformer model

* Add type hints to Timm Backbone model

* Add type hints to TVLT family models

* Add type hints to Vivit family models

* Use the typing instance instead of the python builtin.

* Fix the `replace_return_docstrings` decorator for Vivit model
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent cb91ec67
src/transformers/models/table_transformer/modeling_table_transformer.py

@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torch import Tensor, nn
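
The widened `typing` import is what the annotations in the hunks below draw on, and it is also what the "typing instance instead of the python builtin" bullet in the commit message refers to: `Optional[X]` is shorthand for `Union[X, None]`, and the subscriptable `Tuple`/`List`/`Dict` come from `typing` because the builtin generics only became subscriptable in Python 3.9. A schematic stub, not code from the commit:

from typing import Dict, List, Optional, Tuple, Union

import torch

# Schematic only: shows how the typing constructs compose in a signature.
def forward_stub(
    pixel_values: torch.FloatTensor,                  # required tensor input
    pixel_mask: Optional[torch.FloatTensor] = None,   # Optional[X] == Union[X, None]
    labels: Optional[List[Dict]] = None,              # typing.List/Dict, not list/dict
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Dict]:           # tuple branch when return_dict=False
    ...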
@@ -814,7 +814,7 @@ TABLE_TRANSFORMER_INPUTS_DOCSTRING = r"""
             Pixel values can be obtained using [`DetrImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.
-        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+        pixel_mask (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
             Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
             - 1 for pixels that are real (i.e. **not masked**),
@@ -822,7 +822,7 @@ TABLE_TRANSFORMER_INPUTS_DOCSTRING = r"""
             [What are attention masks?](../glossary#attention-mask)
-        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1190,16 +1190,16 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
     @replace_return_docstrings(output_type=TableTransformerModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TableTransformerModelOutput]:
         r"""
         Returns:
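
A hedged usage sketch of the annotated signature with a randomly initialized model (not code from the commit; assumes timm is installed, and `use_pretrained_backbone=False` just avoids downloading backbone weights):

import torch
from transformers import TableTransformerConfig, TableTransformerModel

model = TableTransformerModel(TableTransformerConfig(use_pretrained_backbone=False))
pixel_values = torch.randn(1, 3, 800, 800)   # FloatTensor, as the new hint states

outputs = model(pixel_values=pixel_values)   # TableTransformerModelOutput branch
print(outputs.last_hidden_state.shape)       # (batch_size, num_queries, d_model)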
@@ -1351,17 +1351,17 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
     @replace_return_docstrings(output_type=TableTransformerObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[List[Dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TableTransformerObjectDetectionOutput]:
         r"""
         labels (`List[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
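
The `List[Dict]` hint matches the docstring above: one dict per image, each carrying at least `class_labels` and `boxes` in the DETR-style target format. An illustrative target for a `TableTransformerForObjectDetection` instance (`model` and `pixel_values` assumed from context, values invented for the example):

import torch

labels = [
    {
        "class_labels": torch.tensor([0]),              # one class id per target object
        "boxes": torch.tensor([[0.5, 0.5, 0.4, 0.2]]),  # normalized (cx, cy, w, h)
    }
]  # len(labels) == batch_size
outputs = model(pixel_values=pixel_values, labels=labels)  # output now carries .loss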
src/transformers/models/timesformer/modeling_timesformer.py
@@ -559,11 +559,11 @@ class TimesformerModel(TimesformerPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         Returns:
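
Only the video tensor is required here. A hedged sketch with a randomly initialized model and the default configuration (8 frames at 224x224; not code from the commit):

import torch
from transformers import TimesformerConfig, TimesformerModel

model = TimesformerModel(TimesformerConfig())
video = torch.randn(1, 8, 3, 224, 224)   # (batch, num_frames, channels, height, width)
outputs = model(pixel_values=video)      # BaseModelOutput branch of the Union
print(outputs.last_hidden_state.shape)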
src/transformers/models/timm_backbone/modeling_timm_backbone.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
+
+import torch

 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
@@ -107,7 +109,12 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
         pass

     def forward(
-        self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
-    ):
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
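
The `Union[BackboneOutput, Tuple[Tensor, ...]]` annotation covers both branches of the `return_dict` fallback visible above. An illustrative sketch, assuming timm is installed and that your transformers version exports `TimmBackbone`/`TimmBackboneConfig` at the top level ("resnet18" is just an example backbone name):

import torch
from transformers import TimmBackbone, TimmBackboneConfig

config = TimmBackboneConfig(backbone="resnet18", use_pretrained_backbone=False)
backbone = TimmBackbone(config)
pixel_values = torch.randn(1, 3, 224, 224)

feats = backbone(pixel_values)                           # BackboneOutput branch
print([f.shape for f in feats.feature_maps])
feats_tuple = backbone(pixel_values, return_dict=False)  # Tuple[Tensor, ...] branch
print(type(feats_tuple))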
src/transformers/models/tvlt/modeling_tvlt.py
@@ -715,16 +715,16 @@ class TvltModel(TvltPreTrainedModel):
     @replace_return_docstrings(output_type=TvltModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        audio_values,
-        pixel_mask=None,
-        audio_mask=None,
-        mask_pixel=False,
-        mask_audio=False,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ) -> Union[tuple, TvltModelOutput]:
+        pixel_values: torch.FloatTensor,
+        audio_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        audio_mask: Optional[torch.FloatTensor] = None,
+        mask_pixel: bool = False,
+        mask_audio: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TvltModelOutput]:
         r"""
         Returns:
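
The change from bare `tuple` to `Tuple[torch.FloatTensor]` in the return annotation follows the library-wide pattern: a `ModelOutput` subclass by default, a plain tuple when `return_dict=False`. A hedged sketch with the default TVLT configuration (input shapes are illustrative, not taken from the commit):

import torch
from transformers import TvltConfig, TvltModel

model = TvltModel(TvltConfig())
pixel_values = torch.randn(1, 8, 3, 224, 224)   # (batch, num_frames, channels, H, W)
audio_values = torch.randn(1, 1, 2048, 128)     # (batch, channels, time, freq) spectrogram

out = model(pixel_values=pixel_values, audio_values=audio_values)
print(out.last_hidden_state.shape)              # TvltModelOutput branch

out = model(pixel_values=pixel_values, audio_values=audio_values, return_dict=False)
print(type(out))                                # plain-tuple branch of the Union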
@@ -1049,17 +1049,17 @@ class TvltForPreTraining(TvltPreTrainedModel):
     @replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        audio_values,
-        pixel_mask=None,
-        audio_mask=None,
-        labels=None,
-        pixel_values_mixed=None,
-        pixel_mask_mixed=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ) -> Union[tuple, TvltForPreTrainingOutput]:
+        pixel_values: torch.FloatTensor,
+        audio_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        audio_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        pixel_values_mixed: Optional[torch.FloatTensor] = None,
+        pixel_mask_mixed: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]:
         r"""
         pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
             Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
@@ -1250,15 +1250,15 @@ class TvltForAudioVisualClassification(TvltPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        audio_values,
-        pixel_mask=None,
-        audio_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-    ) -> Union[tuple, SequenceClassifierOutput]:
+        pixel_values: torch.FloatTensor,
+        audio_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        audio_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
             Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
src/transformers/models/vivit/modeling_vivit.py
@@ -486,15 +486,15 @@ class VivitModel(VivitPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)

     @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -628,13 +628,13 @@ class VivitForVideoClassification(VivitPreTrainedModel):
     @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
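
A hedged end-to-end sketch of this classification forward, with an untrained model and the default Vivit configuration (32 frames at 224x224; `num_labels=5` is arbitrary, not from the commit):

import torch
from transformers import VivitConfig, VivitForVideoClassification

model = VivitForVideoClassification(VivitConfig(num_labels=5))
video = torch.randn(1, 32, 3, 224, 224)   # (batch, num_frames, channels, height, width)
labels = torch.tensor([3])                # LongTensor of shape (batch_size,)

out = model(pixel_values=video, labels=labels)   # ImageClassifierOutput with .loss
print(out.loss, out.logits.shape)                # logits: (batch_size, num_labels)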