Unverified Commit 46d09410 authored by Ian Castillo, committed by GitHub

Add type hints for ViLT models (#18577)

* Add type hints for ViLT models

* Add missing return type annotation for the ViltForTokenClassification class
parent bce36ee0
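
The hunks below all apply the same convention. As a minimal, self-contained sketch (not part of this commit): every keyword argument gains an `Optional[...]` hint, and the return annotation becomes a `Union`, because `return_dict` switches the output between a `ModelOutput` subclass and a plain tuple. `forward_sketch` and its stand-in tensors are illustrative only.

```python
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling


def forward_sketch(
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[BaseModelOutputWithPooling, Tuple[torch.FloatTensor]]:
    # Stand-in tensors for what the real model would compute.
    sequence_output = torch.zeros(1, 2, 768)
    pooled_output = torch.zeros(1, 768)
    if not return_dict:
        # Tuple branch: the Tuple[torch.FloatTensor] side of the Union.
        return (sequence_output, pooled_output)
    # ModelOutput branch: the BaseModelOutputWithPooling side of the Union.
    return BaseModelOutputWithPooling(
        last_hidden_state=sequence_output, pooler_output=pooled_output
    )
```
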
@@ -17,7 +17,7 @@
 import collections.abc
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -761,19 +761,19 @@ class ViltModel(ViltPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        pixel_values=None,
-        pixel_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        image_embeds=None,
-        image_token_type_idx=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        image_token_type_idx: Optional[int] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutputWithPooling, Tuple[torch.FloatTensor]]:
         r"""
         Returns:
@@ -914,19 +914,19 @@ class ViltForMaskedLM(ViltPreTrainedModel):
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        pixel_values=None,
-        pixel_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        image_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1088,19 +1088,19 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        pixel_values=None,
-        pixel_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        image_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
             Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
@@ -1193,19 +1193,19 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        pixel_values=None,
-        pixel_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        image_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels are currently not supported.
@@ -1299,19 +1299,19 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
     @replace_return_docstrings(output_type=ViltForImagesAndTextClassificationOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        pixel_values=None,
-        pixel_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        image_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[ViltForImagesAndTextClassificationOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Binary classification labels.
@@ -1436,19 +1436,19 @@ class ViltForTokenClassification(ViltPreTrainedModel):
     @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        pixel_values=None,
-        pixel_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        image_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        image_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
...
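
For completeness, a hedged usage sketch of how the annotations surface to callers; it is not part of this commit. The checkpoint name is the public ViLT MLM checkpoint and the blank placeholder image is illustrative only.

```python
import torch
from PIL import Image
from transformers import ViltForMaskedLM, ViltProcessor

# Public ViLT checkpoint; any ViLT masked-LM checkpoint would do.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

image = Image.new("RGB", (384, 384))  # placeholder image
inputs = processor(image, "a photo of a [MASK]", return_tensors="pt")

with torch.no_grad():
    # At runtime return_dict=True yields a MaskedLMOutput; statically the
    # annotation is the Union, so strict checkers may require a narrowing check.
    outputs = model(**inputs, return_dict=True)
    logits: torch.FloatTensor = outputs.logits

print(logits.shape)  # (batch_size, sequence_length, vocab_size)
```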