Unverified Commit 0dcdfe86 authored by Martin Pan's avatar Martin Pan Committed by GitHub
Browse files

Add type hints for FNet PyTorch (#16123)

parent f86235ad
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -644,15 +644,15 @@ class FNetForPreTraining(FNetPreTrainedModel): ...@@ -644,15 +644,15 @@ class FNetForPreTraining(FNetPreTrainedModel):
@replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
next_sentence_label=None, next_sentence_label: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, FNetForPreTrainingOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -741,14 +741,14 @@ class FNetForMaskedLM(FNetPreTrainedModel): ...@@ -741,14 +741,14 @@ class FNetForMaskedLM(FNetPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MaskedLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -799,15 +799,15 @@ class FNetForNextSentencePrediction(FNetPreTrainedModel): ...@@ -799,15 +799,15 @@ class FNetForNextSentencePrediction(FNetPreTrainedModel):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs,
): ) -> Union[Tuple, NextSentencePredictorOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
...@@ -900,14 +900,14 @@ class FNetForSequenceClassification(FNetPreTrainedModel): ...@@ -900,14 +900,14 @@ class FNetForSequenceClassification(FNetPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...@@ -985,14 +985,14 @@ class FNetForMultipleChoice(FNetPreTrainedModel): ...@@ -985,14 +985,14 @@ class FNetForMultipleChoice(FNetPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MultipleChoiceModelOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
...@@ -1067,14 +1067,14 @@ class FNetForTokenClassification(FNetPreTrainedModel): ...@@ -1067,14 +1067,14 @@ class FNetForTokenClassification(FNetPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TokenClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
...@@ -1136,15 +1136,15 @@ class FNetForQuestionAnswering(FNetPreTrainedModel): ...@@ -1136,15 +1136,15 @@ class FNetForQuestionAnswering(FNetPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
start_positions=None, start_positions: Optional[torch.Tensor] = None,
end_positions=None, end_positions: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment