Unverified Commit 886b6be0 authored by David Reguera, committed by GitHub
Browse files

Add type hints for several pytorch models (batch-4) (#25749)



* Add type hints for MGP STR model

* Add missing type hints for plbart model

* Add type hints for Pix2struct model

* Add missing type hints to Rag model and tweak the docstring

* Add missing type hints to Sam model

* Add missing type hints to Swin2sr model

* Fix a type hint for Pix2StructTextModel
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Fix typo on Rag model docstring
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Fix linter

---------
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent ed915cff
...@@ -380,7 +380,13 @@ class MgpstrModel(MgpstrPreTrainedModel): ...@@ -380,7 +380,13 @@ class MgpstrModel(MgpstrPreTrainedModel):
return self.embeddings.proj return self.embeddings.proj
@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -437,12 +443,12 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel): ...@@ -437,12 +443,12 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
@replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig) @replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
def forward( def forward(
self, self,
pixel_values, pixel_values: torch.FloatTensor,
output_attentions=None, output_attentions: Optional[bool] = None,
output_a3_attentions=None, output_a3_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
r""" r"""
output_a3_attentions (`bool`, *optional*): output_a3_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
......
...@@ -1387,21 +1387,21 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): ...@@ -1387,21 +1387,21 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs,
): ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
r""" r"""
Returns: Returns:
......
...@@ -1177,7 +1177,7 @@ class PLBartModel(PLBartPreTrainedModel): ...@@ -1177,7 +1177,7 @@ class PLBartModel(PLBartPreTrainedModel):
encoder_outputs: Optional[List[torch.FloatTensor]] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
...@@ -1302,7 +1302,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel): ...@@ -1302,7 +1302,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
encoder_outputs: Optional[List[torch.FloatTensor]] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
......
...@@ -462,16 +462,12 @@ RAG_FORWARD_INPUTS_DOCSTRING = r""" ...@@ -462,16 +462,12 @@ RAG_FORWARD_INPUTS_DOCSTRING = r"""
`question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information. `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
retriever. retriever. If the model was not initialized with a `retriever` `context_input_ids` has to be provided to
the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
If the model has is not initialized with a `retriever` ``context_input_ids` has to be provided to the context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. context_attention_mask Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
(`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, retriever. If the model is not initialized with a `retriever` `context_attention_mask` has to be
returned when *output_retrieved=True*): Attention mask post-processed from the retrieved documents and the provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
question encoder `input_ids` by the retriever.
If the model has is not initialized with a `retriever` `context_attention_mask` has to be provided to the
forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
use_cache (`bool`, *optional*, defaults to `True`): use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). `past_key_values`).
...@@ -545,7 +541,7 @@ class RagModel(RagPreTrainedModel): ...@@ -545,7 +541,7 @@ class RagModel(RagPreTrainedModel):
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
doc_scores: Optional[torch.FloatTensor] = None, doc_scores: Optional[torch.FloatTensor] = None,
context_input_ids: Optional[torch.LongTensor] = None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None, context_attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
......
...@@ -1296,7 +1296,7 @@ class SamModel(SamPreTrainedModel): ...@@ -1296,7 +1296,7 @@ class SamModel(SamPreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None, target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs,
) -> List[Dict[str, torch.Tensor]]: ) -> List[Dict[str, torch.Tensor]]:
r""" r"""
......
...@@ -903,7 +903,7 @@ class Swin2SRModel(Swin2SRPreTrainedModel): ...@@ -903,7 +903,7 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values, pixel_values: torch.FloatTensor,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment