Unverified commit 70a9bc69, authored by Yi Heng Lim, committed by GitHub

Added type hints (#16389)

* Added type hints for PyTorch T5 model

* removed a type hint

* ran make style

* added type hints for ibert pytorch

* added type hints for lxmert pytorch

* removed kwargs type hint and fixed argument order
parent cae394c8
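
Every hunk below applies the same pattern: each `None`-defaulted `forward` argument gains an `Optional[...]` annotation, and the method gains a `Union[...]` return annotation covering both the `ModelOutput` dataclass (returned when `return_dict=True`) and the plain tuple (returned when `return_dict=False`). A minimal sketch of the before/after, using a hypothetical `ToyModel` for illustration:

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class ToyOutput:
    """Stand-in for a Hugging Face ModelOutput dataclass (illustrative only)."""


class ToyModel(nn.Module):
    # Before:  def forward(self, input_ids=None, return_dict=None):
    # After: every None-defaulted argument becomes Optional[...] and the
    # return annotation is a Union over both possible return shapes.
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[ToyOutput, Tuple[torch.FloatTensor]]:
        ...
```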
@@ -18,6 +18,7 @@
 """PyTorch I-BERT model."""

 import math
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -777,16 +778,16 @@ class IBertModel(IBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.FloatTensor]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
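
The `Union` return annotation encodes behavior the models already had: with `return_dict=False` they return a plain tuple, otherwise the output dataclass. A hedged usage sketch (the checkpoint name is illustrative; any I-BERT checkpoint would do):

```python
from transformers import AutoTokenizer, IBertModel

# "kssteven/ibert-roberta-base" is used here as an illustrative checkpoint name.
tokenizer = AutoTokenizer.from_pretrained("kssteven/ibert-roberta-base")
model = IBertModel.from_pretrained("kssteven/ibert-roberta-base")
inputs = tokenizer("Hello world", return_tensors="pt")

# Default: a BaseModelOutputWithPoolingAndCrossAttentions dataclass.
outputs = model(**inputs)
hidden = outputs.last_hidden_state

# Legacy path: a plain tuple, hence the Tuple[...] arm of the Union.
legacy = model(**inputs, return_dict=False)
hidden_legacy = legacy[0]
```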
@@ -882,17 +883,18 @@ class IBertForMaskedLM(IBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
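
The docstring context above explains why `labels` is `Optional[torch.LongTensor]` with values in `[-100, 0, ..., config.vocab_size]`: positions set to `-100` are ignored by the cross-entropy loss, so only masked tokens contribute. A small sketch of building such labels (all token ids are illustrative):

```python
import torch

# Illustrative ids: pretend token 50264 is the <mask> token.
input_ids = torch.tensor([[0, 713, 16, 50264, 2]])
mask_token_id = 50264
original_token_id = 1531  # the token that was masked out (illustrative)

# Loss is computed only where labels != -100, i.e. at the masked position.
labels = torch.full_like(input_ids, -100)
labels[input_ids == mask_token_id] = original_token_id
```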
@@ -990,17 +992,17 @@ class IBertForSequenceClassification(IBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
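
For sequence classification the same `labels` hint covers two loss modes, as the docstring context notes: with `config.num_labels > 1` the labels are class indices for a cross-entropy loss, and with `config.num_labels == 1` they are float targets for an MSE regression loss (strictly, the regression case passes a float tensor, so the `torch.LongTensor` hint describes only the classification path). A sketch of both (values illustrative):

```python
import torch

# num_labels > 1: classification, integer class ids, cross-entropy loss.
labels_classification = torch.tensor([0, 2])

# num_labels == 1: regression, float targets, mean-squared-error loss.
labels_regression = torch.tensor([0.7, -1.2])
```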
@@ -1086,17 +1088,17 @@ class IBertForMultipleChoice(IBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        token_type_ids=None,
-        attention_mask=None,
-        labels=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MultipleChoiceModelOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
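
Multiple choice is the one head where the input layout differs: `input_ids` is `(batch_size, num_choices, sequence_length)` and `labels` holds the index of the correct choice. A shape sketch (values illustrative):

```python
import torch

batch_size, num_choices, seq_len = 2, 4, 16

# One candidate sequence per choice; the model flattens this to
# (batch_size * num_choices, seq_len) before the encoder.
input_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
labels = torch.tensor([1, 3])  # correct choice index for each example
```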
@@ -1181,17 +1183,17 @@ class IBertForTokenClassification(IBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1281,18 +1283,18 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.FloatTensor]]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
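
For question answering, `start_positions` and `end_positions` are both `(batch_size,)` index tensors, which is why they share the `Optional[torch.LongTensor]` hint; the loss averages the cross-entropy of the two position classifiers. A minimal sketch (indices illustrative):

```python
import torch

start_positions = torch.tensor([14])  # token index where the answer span starts
end_positions = torch.tensor([17])    # token index where the answer span ends

# Passing both produces outputs.loss, the mean of the start and end
# cross-entropy losses:
# outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
```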
@@ -19,7 +19,7 @@ import math
 import os
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -180,7 +180,7 @@ class LxmertForPreTrainingOutput(ModelOutput):
     """
-    loss: [torch.FloatTensor] = None
+    loss: Optional[torch.FloatTensor] = None
     prediction_logits: Optional[torch.FloatTensor] = None
     cross_relationship_score: Optional[torch.FloatTensor] = None
     question_answering_score: Optional[torch.FloatTensor] = None
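
This hunk is a bug fix rather than a new hint: `loss: [torch.FloatTensor]` annotates the field with a list literal, which is not a valid type expression, while `Optional[torch.FloatTensor]` (i.e. `Union[torch.FloatTensor, None]`) matches the `= None` default and the sibling fields:

```python
from typing import Optional

import torch

# Invalid: a list literal containing a type is not itself a type,
# so static checkers reject it.
# loss: [torch.FloatTensor] = None

# Valid: Optional[X] is shorthand for Union[X, None].
loss: Optional[torch.FloatTensor] = None
```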
@@ -907,17 +907,17 @@ class LxmertModel(LxmertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        visual_feats=None,
-        visual_pos=None,
-        attention_mask=None,
-        visual_attention_mask=None,
-        token_type_ids=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        visual_feats: Optional[torch.FloatTensor] = None,
+        visual_pos: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        visual_attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[LxmertModelOutput, Tuple[torch.FloatTensor]]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
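
LXMERT's signature differs from the text-only models: alongside `input_ids` it requires region features and their box coordinates, hence the two extra `Optional[torch.FloatTensor]` hints. A shape sketch (dimensions follow LXMERT's default config, where `visual_feat_dim=2048` and positions are 4-d normalized boxes; the random values are placeholders):

```python
import torch

batch_size, num_visual_features = 2, 36

# Pre-extracted region features (e.g. from an R-CNN detector) and their
# normalized bounding boxes.
visual_feats = torch.randn(batch_size, num_visual_features, 2048)
visual_pos = torch.rand(batch_size, num_visual_features, 4)

# outputs = model(input_ids=input_ids, visual_feats=visual_feats, visual_pos=visual_pos)
```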
@@ -1154,22 +1154,22 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
     @replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        visual_feats=None,
-        visual_pos=None,
-        attention_mask=None,
-        visual_attention_mask=None,
-        token_type_ids=None,
-        inputs_embeds=None,
-        labels=None,
-        obj_labels=None,
-        matched_label=None,
-        ans=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        visual_feats: Optional[torch.FloatTensor] = None,
+        visual_pos: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        visual_attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        obj_labels: Optional[Dict[str, Tuple[torch.FloatTensor, torch.FloatTensor]]] = None,
+        matched_label: Optional[torch.LongTensor] = None,
+        ans: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[LxmertForPreTrainingOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
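
`obj_labels` gets the most structured hint in the diff: a dict mapping each visual-loss name to a `(label, score)` tensor pair. A hedged sketch of the expected shape; the key follows LXMERT's visual losses (`"obj"`, `"attr"`, `"feat"`), and the dtypes/shapes are illustrative (the new `Tuple[torch.FloatTensor, torch.FloatTensor]` hint is itself approximate, since classification targets are typically integer tensors):

```python
import torch

batch_size, num_features = 2, 36

obj_labels = {
    # label ids for each detected region, plus a per-region confidence/mask.
    "obj": (
        torch.zeros(batch_size, num_features, dtype=torch.long),
        torch.ones(batch_size, num_features),
    ),
}
```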
@@ -1390,18 +1390,18 @@ class LxmertForQuestionAnswering(LxmertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        visual_feats=None,
-        visual_pos=None,
-        attention_mask=None,
-        visual_attention_mask=None,
-        token_type_ids=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        visual_feats: Optional[torch.FloatTensor] = None,
+        visual_pos: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        visual_attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[LxmertForQuestionAnsweringOutput, Tuple[torch.FloatTensor]]:
         r"""
         labels: (`Torch.Tensor` of shape `(batch_size)`, *optional*):
             A one-hot representation of the correct answer