Unverified Commit 460f36d3 authored by Jack McDonald's avatar Jack McDonald Committed by GitHub
Browse files

Add type hints transfoxl (#16267)

* Add type hint for pt transfo_xl model

* Add type hint for tf transfo_xl model
parent 2afe9cd2
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf import tensorflow as tf
from ...file_utils import ( from ...file_utils import (
...@@ -29,6 +30,7 @@ from ...file_utils import ( ...@@ -29,6 +30,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
get_initializer, get_initializer,
...@@ -1077,17 +1079,17 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc ...@@ -1077,17 +1079,17 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
mems=None, mems: Optional[List[tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[Tuple, TFTransfoXLSequenceClassifierOutputWithPast]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
""" """
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -1215,15 +1215,15 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel): ...@@ -1215,15 +1215,15 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
mems=None, mems: Optional[List[torch.FloatTensor]] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TransfoXLSequenceClassifierOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment