Unverified Commit deb61e5f authored by Dan Tegzes's avatar Dan Tegzes Committed by GitHub
Browse files

Add type hints for Pegasus (#16324)

parent 7cc2c9c6
...@@ -1184,22 +1184,22 @@ class PegasusModel(PegasusPreTrainedModel): ...@@ -1184,22 +1184,22 @@ class PegasusModel(PegasusPreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.Tensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values=None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.Tensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqModelOutput]:
r""" r"""
Returns: Returns:
...@@ -1352,23 +1352,23 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ...@@ -1352,23 +1352,23 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
@add_end_docstrings(PEGASUS_GENERATION_EXAMPLE) @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.Tensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values=None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment