Unverified Commit 5bb211be authored by Tom Mathews's avatar Tom Mathews Committed by GitHub
Browse files

Adding type hints of TF:CTRL (#18264)

parent c8ed1b8b
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
""" TF 2.0 CTRL model.""" """ TF 2.0 CTRL model."""
import warnings import warnings
from typing import Tuple from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -24,6 +24,7 @@ import tensorflow as tf ...@@ -24,6 +24,7 @@ import tensorflow as tf
from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFSharedEmbeddings, TFSharedEmbeddings,
...@@ -256,19 +257,19 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -256,19 +257,19 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
): ) -> Union[Tuple, TFBaseModelOutputWithPast]:
# If using past key value states, only the last tokens # If using past key value states, only the last tokens
# should be given as an input # should be given as an input
...@@ -528,19 +529,19 @@ class TFCTRLModel(TFCTRLPreTrainedModel): ...@@ -528,19 +529,19 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
): ) -> Union[Tuple, TFBaseModelOutputWithPast]:
outputs = self.transformer( outputs = self.transformer(
input_ids=input_ids, input_ids=input_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
...@@ -642,20 +643,20 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -642,20 +643,20 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
): ) -> Union[Tuple, TFCausalLMOutputWithPast]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
...@@ -753,20 +754,20 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific ...@@ -753,20 +754,20 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
): ) -> Union[Tuple, TFSequenceClassifierOutput]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment