Unverified commit ed915cff authored by David Reguera, committed by GitHub

Add type hints for pytorch models (final batch) (#25750)



* Add type hints for table_transformer

* Add type hints to Timesformer model

* Add type hints to Timm Backbone model

* Add type hints to TVLT family models

* Add type hints to Vivit family models

* Use the `typing` generics (`Optional`, `Tuple`, `Union`) instead of the Python builtins in the annotations (see the sketch after the commit metadata below).

* Fix the `replace_return_docstrings` decorator for Vivit model
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent cb91ec67
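For readers skimming the hunks below, here is a minimal, illustrative sketch of the annotation pattern this commit applies throughout: arguments that default to `None` become `typing.Optional[...]`, and the return type becomes a `Union` of a plain tuple or the model's `ModelOutput` subclass, using `typing.Tuple`/`Union` rather than the builtin `tuple`. The function and its body are stand-ins, not code from this commit.

    from typing import Optional, Tuple, Union

    import torch

    from transformers.modeling_outputs import BaseModelOutput


    def forward(
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        """Stand-in for a model forward(): dict-style output by default, tuple otherwise."""
        hidden_states = pixel_values  # placeholder for the real computation
        if return_dict is None or return_dict:
            return BaseModelOutput(last_hidden_state=hidden_states)
        return (hidden_states,)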
@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
@@ -814,7 +814,7 @@ TABLE_TRANSFORMER_INPUTS_DOCSTRING = r"""
             Pixel values can be obtained using [`DetrImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.
-        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+        pixel_mask (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
             Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
             - 1 for pixels that are real (i.e. **not masked**),
@@ -822,7 +822,7 @@ TABLE_TRANSFORMER_INPUTS_DOCSTRING = r"""
             [What are attention masks?](../glossary#attention-mask)
-        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
             Not used by default. Can be used to mask object queries.
         encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1190,16 +1190,16 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
     @replace_return_docstrings(output_type=TableTransformerModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TableTransformerModelOutput]:
         r"""
         Returns:
@@ -1351,17 +1351,17 @@ class TableTransformerForObjectDetection(TableTransformerPreTrainedModel):
     @replace_return_docstrings(output_type=TableTransformerObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        pixel_mask=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_outputs: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[List[Dict]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TableTransformerObjectDetectionOutput]:
         r"""
         labels (`List[Dict]` of len `(batch_size,)`, *optional*):
             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
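The `labels: Optional[List[Dict]]` hint on `TableTransformerForObjectDetection.forward` above refers to the usual DETR-style target format. A rough sketch of what such a list could look like, assuming the conventional `class_labels`/`boxes` keys (illustrative values, not part of this commit):

    from typing import Dict, List

    import torch

    # One dict per image in the batch; keys assumed to follow the DETR-style
    # convention: integer class ids plus normalized (cx, cy, w, h) boxes.
    labels: List[Dict] = [
        {
            "class_labels": torch.tensor([0, 1]),  # illustrative class ids
            "boxes": torch.tensor([[0.50, 0.50, 0.90, 0.80],
                                   [0.30, 0.50, 0.25, 0.80]]),
        }
    ]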
@@ -559,11 +559,11 @@ class TimesformerModel(TimesformerPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         Returns:
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
+
+import torch
 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
@@ -107,7 +109,12 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
         pass
     def forward(
-        self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
@@ -715,16 +715,16 @@ class TvltModel(TvltPreTrainedModel):
     @replace_return_docstrings(output_type=TvltModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        audio_values,
-        pixel_mask=None,
-        audio_mask=None,
-        mask_pixel=False,
-        mask_audio=False,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ) -> Union[tuple, TvltModelOutput]:
+        pixel_values: torch.FloatTensor,
+        audio_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        audio_mask: Optional[torch.FloatTensor] = None,
+        mask_pixel: bool = False,
+        mask_audio: bool = False,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TvltModelOutput]:
         r"""
         Returns:
@@ -1049,17 +1049,17 @@ class TvltForPreTraining(TvltPreTrainedModel):
     @replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        audio_values,
-        pixel_mask=None,
-        audio_mask=None,
-        labels=None,
-        pixel_values_mixed=None,
-        pixel_mask_mixed=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ) -> Union[tuple, TvltForPreTrainingOutput]:
+        pixel_values: torch.FloatTensor,
+        audio_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        audio_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        pixel_values_mixed: Optional[torch.FloatTensor] = None,
+        pixel_mask_mixed: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]:
         r"""
         pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
             Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
@@ -1250,15 +1250,15 @@ class TvltForAudioVisualClassification(TvltPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        audio_values,
-        pixel_mask=None,
-        audio_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-    ) -> Union[tuple, SequenceClassifierOutput]:
+        pixel_values: torch.FloatTensor,
+        audio_values: torch.FloatTensor,
+        pixel_mask: Optional[torch.FloatTensor] = None,
+        audio_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
             Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
@@ -486,15 +486,15 @@ class VivitModel(VivitPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)
     @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -628,13 +628,13 @@ class VivitForVideoClassification(VivitPreTrainedModel):
     @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
-        head_mask=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
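Besides the annotations, the Vivit hunk above also points `replace_return_docstrings` at `BaseModelOutputWithPooling`, so the generated docs match what `VivitModel.forward` actually returns. From the caller's side, the new return hints can be narrowed as in this small sketch (a hypothetical helper, assuming an already-instantiated `VivitForVideoClassification`; not code from this commit):

    from typing import Tuple, Union

    import torch

    from transformers import VivitForVideoClassification
    from transformers.modeling_outputs import ImageClassifierOutput


    def classify(model: VivitForVideoClassification, video: torch.FloatTensor) -> ImageClassifierOutput:
        # forward() is now annotated as Union[Tuple[torch.FloatTensor], ImageClassifierOutput];
        # requesting the dict-style output lets us narrow it for type checkers and at runtime.
        outputs: Union[Tuple[torch.FloatTensor], ImageClassifierOutput] = model(
            pixel_values=video, return_dict=True
        )
        assert isinstance(outputs, ImageClassifierOutput)
        return outputs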