Unverified commit 87a9af53, authored by Dan Tegzes, committed by GitHub

Add type hints for ProphetNet PyTorch (#16272)

parent 7b262b96
@@ -18,7 +18,7 @@ import copy
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -1275,14 +1275,14 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Returns:
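For context, the new `Union[Tuple, BaseModelOutput]` annotation only spells out behavior the encoder already has: with `return_dict=True` it returns a `BaseModelOutput`, with `return_dict=False` a plain tuple. A minimal sketch (the standalone encoder checkpoint name follows the existing docstring example and is illustrative):

```python
import torch
from transformers import ProphetNetTokenizer, ProphetNetEncoder

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
# Illustrative standalone encoder checkpoint, as used in the model's docstring example.
model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# return_dict=True -> BaseModelOutput with named fields
outputs = model(**inputs, return_dict=True)
hidden = outputs.last_hidden_state

# return_dict=False -> plain tuple, the other arm of Union[Tuple, BaseModelOutput]
outputs_tuple = model(**inputs, return_dict=False)
assert isinstance(outputs_tuple, tuple)
```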
@@ -1422,19 +1422,19 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
     @replace_return_docstrings(output_type=ProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProphetNetDecoderModelOutput]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
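The `past_key_values` hint uses the nested-tuple cache layout the library uses elsewhere: one inner tuple per decoder layer, holding that layer's cached key/value tensors. A hedged sketch of how the cache surfaces when driving the standalone decoder (checkpoint and kwargs follow the class docstring example):

```python
import torch
from transformers import ProphetNetTokenizer, ProphetNetDecoder

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# With use_cache=True the decoder returns its key/value cache,
# typed as Tuple[Tuple[torch.Tensor]]: one inner tuple per layer.
outputs = model(**inputs, use_cache=True, return_dict=True)
cache = outputs.past_key_values
print(len(cache), len(cache[0]))  # (num_decoder_layers, tensors cached per layer)

# In a generation loop this cache is fed back through past_key_values
# together with only the newly generated token.
```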
@@ -1784,22 +1784,22 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
     @replace_return_docstrings(output_type=ProphetNetSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
         encoder_outputs: Optional[Tuple] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProphetNetSeq2SeqModelOutput]:
         r"""
         Returns:
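Although `decoder_input_ids` is annotated `Optional`, the base model still needs it (or `decoder_inputs_embeds`) at call time, since it has no labels to shift. A short sketch based on the class docstring example; note the two output streams that make `ProphetNetSeq2SeqModelOutput` wider than a vanilla seq2seq output:

```python
from transformers import ProphetNetTokenizer, ProphetNetModel

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")

input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

last_hidden_state = outputs.last_hidden_state              # main stream hidden states
last_hidden_state_ngram = outputs.last_hidden_state_ngram  # predict (n-gram) stream hidden states
```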
@@ -1900,23 +1900,23 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     @replace_return_docstrings(output_type=ProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProphetNetSeq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
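For the `labels` path: passing labels makes the model compute the LM loss (and, when `decoder_input_ids` is not given, derive it by shifting the labels), so the annotated `ProphetNetSeq2SeqLMOutput` carries both `loss` and `logits`. A hedged sketch; the target sentence is purely illustrative:

```python
from transformers import ProphetNetTokenizer, ProphetNetForConditionalGeneration

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
labels = tokenizer("owning a dog is good for you", return_tensors="pt").input_ids

# With labels set, the loss is computed; return_dict=True yields a
# ProphetNetSeq2SeqLMOutput, return_dict=False the tuple arm of the Union.
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=labels)
loss = outputs.loss
logits = outputs.logits
```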
@@ -2123,20 +2123,20 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ProphetNetDecoderLMOutput]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if