Unverified commit 88f7c564 authored by Dan Tegzes, committed by GitHub

Added type hints for Reformer (#16175)

parent 16399d61
@@ -20,7 +20,7 @@ from collections import namedtuple
 from dataclasses import dataclass
 from functools import reduce
 from operator import mul
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -1995,18 +1995,18 @@ class ReformerModel(ReformerPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        num_hashes=None,
-        past_buckets_states=None,
-        use_cache=None,
-        output_hidden_states=None,
-        output_attentions=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        num_hashes: Optional[int] = None,
+        past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ReformerModelOutput]:
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
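As a sanity check on the new annotation (this sketch is not part of the commit, and the checkpoint name is only an illustrative public Reformer weight), the `Union[Tuple, ReformerModelOutput]` return type mirrors the existing `return_dict` behavior at runtime:

import torch
from transformers import ReformerModel

# Illustrative public checkpoint; any Reformer checkpoint behaves the same way.
model = ReformerModel.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

input_ids = torch.tensor([[13, 7, 25, 2]])  # toy token ids

with torch.no_grad():
    # return_dict=True resolves the Union to the ReformerModelOutput branch
    out = model(input_ids=input_ids, return_dict=True)
    print(type(out).__name__)            # ReformerModelOutput
    print(out.last_hidden_state.shape)   # (1, 4, hidden_size)

    # return_dict=False resolves it to the plain Tuple branch
    out_tuple = model(input_ids=input_ids, return_dict=False)
    print(isinstance(out_tuple, tuple))  # True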
@@ -2202,19 +2202,19 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        position_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        num_hashes=None,
-        past_buckets_states=None,
-        use_cache=None,
-        output_hidden_states=None,
-        output_attentions=None,
-        return_dict=None,
-        labels=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        num_hashes: Optional[int] = None,
+        past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
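For illustration (again, not part of the diff), the new `labels: Optional[torch.Tensor]` hint and the `Union[Tuple, CausalLMOutput]` return type work together as below; `lm_loss` is a hypothetical helper written only for this sketch:

from typing import Tuple, Union

import torch
from transformers import ReformerModelWithLMHead
from transformers.modeling_outputs import CausalLMOutput


def lm_loss(model: ReformerModelWithLMHead, input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: forward pass with labels, returning the LM loss."""
    out: Union[Tuple, CausalLMOutput] = model(
        input_ids=input_ids, labels=input_ids, return_dict=True
    )
    # Narrow the Union for static checkers; with return_dict=True this always holds.
    assert isinstance(out, CausalLMOutput)
    return out.loss

With `return_dict=False` the same call would take the `Tuple` branch, which is why the `isinstance` narrowing is needed to satisfy a type checker such as mypy.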
@@ -2318,17 +2318,17 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        position_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        num_hashes=None,
-        labels=None,
-        output_hidden_states=None,
-        output_attentions=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        num_hashes: Optional[int] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -2400,17 +2400,17 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        position_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        num_hashes=None,
-        labels=None,
-        output_hidden_states=None,
-        output_attentions=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        num_hashes: Optional[int] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -2519,18 +2519,18 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        position_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        num_hashes=None,
-        start_positions=None,
-        end_positions=None,
-        output_hidden_states=None,
-        output_attentions=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        num_hashes: Optional[int] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
...
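Since the decorators on these `forward` methods preserve the original function, one quick way to confirm the new hints are visible at runtime (a sketch, not part of the commit) is signature introspection:

import inspect

from transformers import ReformerModel

sig = inspect.signature(ReformerModel.forward)
print(sig.parameters["num_hashes"].annotation)  # typing.Optional[int]
print(sig.return_annotation)                    # Union[Tuple, ReformerModelOutput]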