Unverified Commit 46d09410 authored by Ian Castillo's avatar Ian Castillo Committed by GitHub
Browse files

Add type hints for ViLT models (#18577)

* Add type hints for Vilt models

* Add missing return type for TokenClassification class
parent bce36ee0
......@@ -17,7 +17,7 @@
import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -761,19 +761,19 @@ class ViltModel(ViltPreTrainedModel):
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
pixel_values=None,
pixel_mask=None,
head_mask=None,
inputs_embeds=None,
image_embeds=None,
image_token_type_idx=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
image_token_type_idx: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[BaseModelOutputWithPooling, Tuple[torch.FloatTensor]]:
r"""
Returns:
......@@ -914,19 +914,19 @@ class ViltForMaskedLM(ViltPreTrainedModel):
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
pixel_values=None,
pixel_mask=None,
head_mask=None,
inputs_embeds=None,
image_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
r"""
labels (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
Labels for computing the masked language modeling loss. Indices should be in *[-100, 0, ...,
......@@ -1088,19 +1088,19 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
pixel_values=None,
pixel_mask=None,
head_mask=None,
inputs_embeds=None,
image_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
......@@ -1193,19 +1193,19 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
pixel_values=None,
pixel_mask=None,
head_mask=None,
inputs_embeds=None,
image_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels are currently not supported.
......@@ -1299,19 +1299,19 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
@replace_return_docstrings(output_type=ViltForImagesAndTextClassificationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
pixel_values=None,
pixel_mask=None,
head_mask=None,
inputs_embeds=None,
image_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[ViltForImagesAndTextClassificationOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Binary classification labels.
......@@ -1436,19 +1436,19 @@ class ViltForTokenClassification(ViltPreTrainedModel):
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
pixel_values=None,
pixel_mask=None,
head_mask=None,
inputs_embeds=None,
image_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment