"tools/vscode:/vscode.git/clone" did not exist on "6623dff3d7cc1831a19edd2ab17b9e598d7b96bf"
Unverified Commit 460f36d3 authored by Jack McDonald's avatar Jack McDonald Committed by GitHub
Browse files

Add type hints transfoxl (#16267)

* Add type hint for pt transfo_xl model

* Add type hint for tf transfo_xl model
parent 2afe9cd2
......@@ -18,8 +18,9 @@
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...file_utils import (
......@@ -29,6 +30,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
......@@ -1077,17 +1079,17 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
)
def call(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
mems: Optional[List[tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
**kwargs,
):
) -> Union[Tuple, TFTransfoXLSequenceClassifierOutputWithPast]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
......
......@@ -19,7 +19,7 @@
"""
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -1215,15 +1215,15 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
)
def forward(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.Tensor] = None,
mems: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TransfoXLSequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment