Unverified Commit 4d9e45f3 authored by David Reguera, committed by GitHub

Add type hints for several pytorch models (batch-3) (#25705)

* Add missing type hints for ErnieM family

* Add missing type hints for EsmForProteinFolding model

* Add missing type hints for Graphormer model

* Add type hints for InstructBlipQFormer model

* Add missing type hints for LayoutLMForMaskedLM model

* Add missing type hints for LukeForEntitySpanClassification model
parent 8b0a7bfc
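All six commits apply the same pattern: parameters that default to `None` gain `Optional[...]` annotations, and `forward` gains a return annotation of `Union[Tuple[...], <ModelOutput>]` wherever the model can return either a plain tuple or a `ModelOutput` object (models without a `return_dict` switch, such as `EsmForProteinFolding`, get the bare output class instead). A minimal sketch of the pattern, using illustrative names that are not part of this diff:

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch


@dataclass
class ToyModelOutput:
    # Stand-in for a transformers ModelOutput subclass (illustrative only).
    last_hidden_state: torch.FloatTensor = None


class ToyModel(torch.nn.Module):
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], ToyModelOutput]:
        # Only the annotations are new; changes of this kind leave the body untouched.
        hidden_states = torch.zeros(*input_ids.shape, 8)
        if not return_dict:
            return (hidden_states,)
        return ToyModelOutput(last_hidden_state=hidden_states)
```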
@@ -538,7 +538,7 @@ class ErnieMModel(ErnieMPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
@@ -647,7 +647,7 @@ class ErnieMForSequenceClassification(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = True,
         labels: Optional[torch.Tensor] = None,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -744,7 +744,7 @@ class ErnieMForMultipleChoice(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -837,7 +837,7 @@ class ErnieMForTokenClassification(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = True,
         labels: Optional[torch.Tensor] = None,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -914,7 +914,7 @@ class ErnieMForQuestionAnswering(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1003,7 +1003,7 @@ class ErnieMForInformationExtraction(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for position (index) for computing the start_positions loss. Position outside of the sequence are
@@ -2086,11 +2086,11 @@ class EsmForProteinFolding(EsmPreTrainedModel):
     def forward(
         self,
         input_ids: torch.Tensor,
-        attention_mask: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         masking_pattern: Optional[torch.Tensor] = None,
         num_recycles: Optional[int] = None,
-    ):
+    ) -> EsmForProteinFoldingOutput:
         r"""
         Returns:
@@ -816,8 +816,8 @@ class GraphormerModel(GraphormerPreTrainedModel):
         out_degree: torch.LongTensor,
         spatial_pos: torch.LongTensor,
         attn_edge_type: torch.LongTensor,
-        perturb=None,
-        masked_tokens=None,
+        perturb: Optional[torch.FloatTensor] = None,
+        masked_tokens: None = None,
         return_dict: Optional[bool] = None,
         **unused,
     ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
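One detail in the Graphormer hunk is easy to misread: `masked_tokens` is annotated as `None`, not `Optional[torch.Tensor]`. That is the literal type `None` (shorthand for `type(None)`), which tells type checkers the parameter is only a placeholder and may not carry a value. A small illustration of the semantics, with a hypothetical function name:

```python
def forward_stub(masked_tokens: None = None) -> None:
    # Annotating a parameter as `None` means the only valid argument is None;
    # a static checker such as mypy rejects forward_stub(masked_tokens=1)
    # as an incompatible argument type.
    assert masked_tokens is None
```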
@@ -1124,19 +1124,19 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
     def forward(
         self,
-        input_ids,
-        attention_mask=None,
-        position_ids=None,
-        query_embeds=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        query_embeds: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -891,8 +891,8 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1664,7 +1664,7 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask=None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         entity_ids: Optional[torch.LongTensor] = None,
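One quick way to confirm the new annotations, assuming a transformers build that includes this change is installed (a verification snippet, not part of the PR):

```python
import inspect

from transformers import LayoutLMForMaskedLM, LukeForEntitySpanClassification

# Both parameters below were previously untyped; with this change they
# should report Optional[torch.FloatTensor] annotations.
sig = inspect.signature(LayoutLMForMaskedLM.forward)
print(sig.parameters["encoder_hidden_states"].annotation)

sig = inspect.signature(LukeForEntitySpanClassification.forward)
print(sig.parameters["attention_mask"].annotation)
```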