"tests/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "a8a9b2e55124074ef0e75690c8c30a126bb409c3"
Unverified Commit 4d9e45f3 authored by David Reguera, committed by GitHub

Add type hints for several pytorch models (batch-3) (#25705)

* Add missing type hints for ErnieM family

* Add missing type hints for EsmForProteinFolding model

* Add missing type hints for Graphormer model

* Add type hints for InstructBlipQFormer model

* Add missing type hints for LayoutLMForMaskedLM model

* Add missing type hints for LukeForEntitySpanClassification model
parent 8b0a7bfc
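The hunks below all apply the same pattern: each `forward` signature gains parameter annotations and, where the model honours `return_dict`, a return annotation of the form `Union[Tuple[...], <ModelOutput subclass>]`, since these models return a plain tuple when `return_dict=False` and a structured output otherwise. A minimal sketch of that convention, with illustrative class names that are not taken from this diff:

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn


@dataclass
class ToyModelOutput:
    # Illustrative stand-in for a transformers ModelOutput subclass.
    last_hidden_state: Optional[torch.FloatTensor] = None


class ToyModel(nn.Module):
    def forward(
        self,
        input_ids: torch.LongTensor,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], ToyModelOutput]:
        # The Union return annotation mirrors the two return styles:
        # a plain tuple when return_dict is falsy, a dataclass otherwise.
        hidden = torch.zeros(*input_ids.shape, 4)
        if not return_dict:
            return (hidden,)
        return ToyModelOutput(last_hidden_state=hidden)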
@@ -538,7 +538,7 @@ class ErnieMModel(ErnieMPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
@@ -647,7 +647,7 @@ class ErnieMForSequenceClassification(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = True,
         labels: Optional[torch.Tensor] = None,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -744,7 +744,7 @@ class ErnieMForMultipleChoice(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -837,7 +837,7 @@ class ErnieMForTokenClassification(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         return_dict: Optional[bool] = True,
         labels: Optional[torch.Tensor] = None,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -914,7 +914,7 @@ class ErnieMForQuestionAnswering(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1003,7 +1003,7 @@ class ErnieMForInformationExtraction(ErnieMPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = True,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for position (index) for computing the start_positions loss. Position outside of the sequence are
...
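A quick way to confirm that return annotations like the ones above are actually attached to the methods is to inspect the signature at runtime. A small sketch using only the standard library, assuming a transformers install that includes the ErnieM models:

import inspect

from transformers import ErnieMForSequenceClassification

# After this change the return annotation should read
# Union[Tuple[torch.FloatTensor], SequenceClassifierOutput].
sig = inspect.signature(ErnieMForSequenceClassification.forward)
print(sig.return_annotation)

# The parameter annotations added in the same batch are visible as well.
for name, param in sig.parameters.items():
    print(f"{name}: {param.annotation}")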
@@ -2086,11 +2086,11 @@ class EsmForProteinFolding(EsmPreTrainedModel):
     def forward(
         self,
         input_ids: torch.Tensor,
-        attention_mask: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         masking_pattern: Optional[torch.Tensor] = None,
         num_recycles: Optional[int] = None,
-    ):
+    ) -> EsmForProteinFoldingOutput:
         r"""
         Returns:
...
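Unlike the other hunks, the ESM folding head gets a bare `EsmForProteinFoldingOutput` return annotation rather than a `Union`, presumably because its `forward` (shown in full above) takes no `return_dict` flag and always builds the structured output. A toy sketch of that simpler shape, with made-up names that are not from the diff:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyFoldingOutput:
    positions: Optional[torch.FloatTensor] = None


def fold(
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> ToyFoldingOutput:
    # Only one return style exists, so the annotation is the output class itself.
    return ToyFoldingOutput(positions=torch.zeros(*input_ids.shape, 3))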
@@ -816,8 +816,8 @@ class GraphormerModel(GraphormerPreTrainedModel):
         out_degree: torch.LongTensor,
         spatial_pos: torch.LongTensor,
         attn_edge_type: torch.LongTensor,
-        perturb=None,
-        masked_tokens=None,
+        perturb: Optional[torch.FloatTensor] = None,
+        masked_tokens: None = None,
         return_dict: Optional[bool] = None,
         **unused,
     ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
...
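The `masked_tokens: None = None` annotation above is narrower than an `Optional[...]`: it says the base `GraphormerModel` only ever expects `None` for this argument, so a static type checker will flag any attempt to pass real token indices at this level. A hedged sketch of that behaviour with a made-up function, not code from the diff:

from typing import Optional

import torch


def encode_graph(
    perturb: Optional[torch.FloatTensor] = None,
    masked_tokens: None = None,
) -> int:
    # masked_tokens is annotated as None: the argument exists for API
    # compatibility but is never used here, and mypy/pyright will reject
    # any non-None value passed for it.
    assert masked_tokens is None
    return 0 if perturb is None else perturb.numel()


encode_graph(perturb=torch.zeros(2, 3))       # fine
# encode_graph(masked_tokens=torch.ones(2))   # rejected by a type checker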
@@ -1124,19 +1124,19 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
     def forward(
         self,
-        input_ids,
-        attention_mask=None,
-        position_ids=None,
-        query_embeds=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        query_embeds: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
...
@@ -891,8 +891,8 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
...
@@ -1664,7 +1664,7 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask=None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         entity_ids: Optional[torch.LongTensor] = None,
...
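With the `Union[Tuple[...], ...]` annotations in place, downstream code that consumes either return style can be written and type-checked explicitly. A hedged sketch of such a consumer; the helper name is made up, and it relies only on the documented behaviour that the structured output exposes `.logits` and that, when no loss is returned, logits come first in the tuple form:

from typing import Tuple, Union

import torch
from transformers.modeling_outputs import SequenceClassifierOutput


def extract_logits(
    output: Union[Tuple[torch.FloatTensor, ...], SequenceClassifierOutput]
) -> torch.FloatTensor:
    # Structured output when the model was called with return_dict=True.
    if isinstance(output, SequenceClassifierOutput):
        return output.logits
    # Tuple form otherwise; with no labels passed, logits are the first element.
    return output[0]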