Unverified Commit a23a7c0c authored by Thomas Chaigneau's avatar Thomas Chaigneau Committed by GitHub
Browse files

Add flaubert types (#16118)

* Add type hints for FlauBERT PyTorch base model. Other downstream tasks are inherited from XLM RoBERTa.

* Add type hints for FlauBERT TensorFlow models.

* fix output for TFFlaubertWithLMHeadModel
parent 366c18f4
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import random import random
from typing import Dict, Optional, Tuple, Union
import torch import torch
from packaging import version from packaging import version
...@@ -153,19 +154,19 @@ class FlaubertModel(XLMModel): ...@@ -153,19 +154,19 @@ class FlaubertModel(XLMModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
langs=None, langs: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
lengths=None, lengths: Optional[torch.LongTensor] = None,
cache=None, cache: Optional[Dict[str, torch.FloatTensor]] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......
...@@ -20,8 +20,9 @@ import itertools ...@@ -20,8 +20,9 @@ import itertools
import random import random
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
...@@ -243,21 +244,21 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel): ...@@ -243,21 +244,21 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
langs=None, langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
lengths=None, lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
cache=None, cache: Optional[Dict[str, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[Tuple, TFBaseModelOutput]:
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
...@@ -492,21 +493,21 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): ...@@ -492,21 +493,21 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
input_ids=None, input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
langs=None, langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
lengths=None, lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
cache=None, cache: Optional[Dict[str, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[Tuple, TFBaseModelOutput]:
# removed: src_enc=None, src_len=None # removed: src_enc=None, src_len=None
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
...@@ -827,21 +828,21 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel): ...@@ -827,21 +828,21 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
langs=None, langs: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids=None, token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids=None, position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
lengths=None, lengths: Optional[Union[np.ndarray, tf.Tensor]] = None,
cache=None, cache: Optional[Dict[str, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[Tuple, TFFlaubertWithLMHeadModelOutput]:
inputs = input_processing( inputs = input_processing(
func=self.call, func=self.call,
config=self.config, config=self.config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment