Unverified Commit 41ec5d0c authored by Thomas's avatar Thomas Committed by GitHub
Browse files

Added type hints for TF: TransfoXL (#19380)

* Added type hints for TF: TransfoXL
* Added type hints for TF: TransfoXL

* Change type hints for training

* Change type hints for training
parent b29ebdf4
......@@ -542,14 +542,15 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
@unpack_inputs
def call(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
mems: Optional[List[tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: bool = False,
):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
......@@ -894,14 +895,14 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
)
def call(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
mems: Optional[List[tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
):
outputs = self.transformer(
input_ids=input_ids,
......@@ -974,15 +975,15 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
)
def call(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
mems: Optional[List[tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: bool = False,
):
if input_ids is not None:
bsz, tgt_len = shape_list(input_ids)[:2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment