"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c02421883b2a59e075ec87de8d82f02a944fb5e8"
Unverified Commit 3f8360a7 authored by Pepijn Boers's avatar Pepijn Boers Committed by GitHub
Browse files

Add type hints for TFDistilBert (#16107)



* Add type hints for TFDistilBert

* Update src/transformers/models/distilbert/modeling_tf_distilbert.py
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent 97e32b78
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
""" """
import warnings import warnings
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
...@@ -37,6 +39,7 @@ from ...modeling_tf_outputs import ( ...@@ -37,6 +39,7 @@ from ...modeling_tf_outputs import (
) )
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss, TFMaskedLanguageModelingLoss,
TFModelInputType,
TFMultipleChoiceLoss, TFMultipleChoiceLoss,
TFPreTrainedModel, TFPreTrainedModel,
TFQuestionAnsweringLoss, TFQuestionAnsweringLoss,
...@@ -546,16 +549,16 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel): ...@@ -546,16 +549,16 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -661,17 +664,17 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel ...@@ -661,17 +664,17 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -762,17 +765,17 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque ...@@ -762,17 +765,17 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...@@ -857,17 +860,17 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla ...@@ -857,17 +860,17 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
...@@ -964,17 +967,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic ...@@ -964,17 +967,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
...@@ -1089,18 +1092,18 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn ...@@ -1089,18 +1092,18 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
start_positions=None, start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
end_positions=None, end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
r""" r"""
start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment