Unverified Commit 41ec5d0c authored by Thomas's avatar Thomas Committed by GitHub
Browse files

Added type hints for TF: TransfoXL (#19380)

* Added type hints for TF: TransfoXL
* Added type hints for TF: TransfoXL

* Change type hints for training

* Change type hints for training
parent b29ebdf4
...@@ -542,14 +542,15 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): ...@@ -542,14 +542,15 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
mems=None, mems: Optional[List[tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: bool = False,
): ):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
...@@ -894,14 +895,14 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel): ...@@ -894,14 +895,14 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
mems=None, mems: Optional[List[tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: bool = False,
): ):
outputs = self.transformer( outputs = self.transformer(
input_ids=input_ids, input_ids=input_ids,
...@@ -974,15 +975,15 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): ...@@ -974,15 +975,15 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
mems=None, mems: Optional[List[tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: bool = False,
): ):
if input_ids is not None: if input_ids is not None:
bsz, tgt_len = shape_list(input_ids)[:2] bsz, tgt_len = shape_list(input_ids)[:2]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment