Unverified Commit 88f7c564 authored by Dan Tegzes's avatar Dan Tegzes Committed by GitHub
Browse files

Added type hints for Reformer (#16175)

parent 16399d61
......@@ -20,7 +20,7 @@ from collections import namedtuple
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -1995,18 +1995,18 @@ class ReformerModel(ReformerPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
past_buckets_states=None,
use_cache=None,
output_hidden_states=None,
output_attentions=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ReformerModelOutput]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -2202,19 +2202,19 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
)
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
past_buckets_states=None,
use_cache=None,
output_hidden_states=None,
output_attentions=None,
return_dict=None,
labels=None,
):
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
......@@ -2318,17 +2318,17 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
)
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
labels=None,
output_hidden_states=None,
output_attentions=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
......@@ -2400,17 +2400,17 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
)
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
labels=None,
output_hidden_states=None,
output_attentions=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......@@ -2519,18 +2519,18 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
)
def forward(
self,
input_ids=None,
position_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
num_hashes=None,
start_positions=None,
end_positions=None,
output_hidden_states=None,
output_attentions=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
num_hashes: Optional[int] = None,
start_positions: Optional[torch.Tensor] = None,
end_positions: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment