"vscode:/vscode.git/clone" did not exist on "98d40fed3a4515077163adab9dfd8fb2fccf1267"
Unverified commit fb65f65e, authored by Will Rice and committed by GitHub

Add TFHubertModel (#12206)

* TFHubert

* Update with TFWav2Vec Bug Fixes

* Add OOV Error

* Feedback changes

* Fix kwargs call
parent 934222e3
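
For context, a minimal usage sketch of the TF model this commit adds. The checkpoint name and the from_pt conversion are illustrative assumptions, not part of this diff:

    import tensorflow as tf
    from transformers import TFHubertModel

    # Illustrative checkpoint; TF weights are converted from the PyTorch ones here.
    model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960", from_pt=True)

    # A batch of one raw 16 kHz waveform, about one second long.
    input_values = tf.random.normal((1, 16000))

    outputs = model(input_values)
    print(outputs.last_hidden_state.shape)  # (1, num_frames, hidden_size)
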
@@ -355,7 +355,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Hubert | ❌ | ❌ | ✅ | ❌ | ❌ |
+| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
@@ -63,3 +63,16 @@ HubertForCTC
 .. autoclass:: transformers.HubertForCTC
     :members: forward
+
+TFHubertModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFHubertModel
+    :members: call
+
+
+TFHubertForCTC
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFHubertForCTC
+    :members: call
@@ -1343,6 +1343,14 @@ if is_tf_available():
             "TFGPT2PreTrainedModel",
         ]
     )
+    _import_structure["models.hubert"].extend(
+        [
+            "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFHubertForCTC",
+            "TFHubertModel",
+            "TFHubertPreTrainedModel",
+        ]
+    )
     _import_structure["models.layoutlm"].extend(
         [
             "TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2804,6 +2812,12 @@ if TYPE_CHECKING:
         TFGPT2Model,
         TFGPT2PreTrainedModel,
     )
+    from .models.hubert import (
+        TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFHubertForCTC,
+        TFHubertModel,
+        TFHubertPreTrainedModel,
+    )
     from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
     from .models.longformer import (
         TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
...
@@ -101,6 +101,7 @@ from ..funnel.modeling_tf_funnel import (
     TFFunnelModel,
 )
 from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
+from ..hubert.modeling_tf_hubert import TFHubertModel
 from ..layoutlm.modeling_tf_layoutlm import (
     TFLayoutLMForMaskedLM,
     TFLayoutLMForSequenceClassification,
@@ -204,6 +205,7 @@ from .configuration_auto import (
     FlaubertConfig,
     FunnelConfig,
     GPT2Config,
+    HubertConfig,
     LayoutLMConfig,
     LEDConfig,
     LongformerConfig,
@@ -266,6 +268,7 @@ TF_MODEL_MAPPING = OrderedDict(
         (BlenderbotConfig, TFBlenderbotModel),
         (BlenderbotSmallConfig, TFBlenderbotSmallModel),
         (Wav2Vec2Config, TFWav2Vec2Model),
+        (HubertConfig, TFHubertModel),
     ]
 )
...
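
The (HubertConfig, TFHubertModel) entry added to TF_MODEL_MAPPING above is what lets the TF auto class build the new model from a Hubert config. A small sketch of that dispatch (randomly initialized, for illustration only):

    from transformers import HubertConfig, TFAutoModel

    # With the mapping in place, the auto class resolves a Hubert config
    # to the new TF implementation.
    config = HubertConfig()
    model = TFAutoModel.from_config(config)
    print(type(model).__name__)  # TFHubertModel
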
@@ -17,7 +17,7 @@
 # limitations under the License.

 from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_tf_available, is_torch_available

 _import_structure = {
@@ -33,6 +33,14 @@ if is_torch_available():
     ]

+if is_tf_available():
+    _import_structure["modeling_tf_hubert"] = [
+        "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "TFHubertForCTC",
+        "TFHubertModel",
+        "TFHubertPreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
@@ -44,6 +52,14 @@ if TYPE_CHECKING:
         HubertPreTrainedModel,
     )

+    if is_tf_available():
+        from .modeling_tf_hubert import (
+            TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFHubertForCTC,
+            TFHubertModel,
+            TFHubertPreTrainedModel,
+        )
+
 else:
     import sys
...
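
The hunks above gate both the lazy import structure and the type-checking imports on is_tf_available(). Downstream code can use the same guard; a minimal sketch:

    from transformers.file_utils import is_tf_available

    # Import the TF model only when TensorFlow is actually installed,
    # instead of failing at import time.
    if is_tf_available():
        from transformers import TFHubertModel
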
This diff is collapsed.
@@ -83,13 +83,6 @@ def input_values_processing(func, config, input_values, **kwargs):
     output = {}

     allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)

-    if len(kwargs["kwargs_call"]) > 0:
-        raise ValueError(
-            f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
-        )
-
-    kwargs.pop("kwargs_call")
-
     for k, v in kwargs.items():
         if isinstance(v, allowed_types) or v is None:
             output[k] = v
@@ -1398,7 +1391,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
-        **kwargs: Any,
     ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         """
@@ -1438,7 +1430,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
         )

         inputs["output_hidden_states"] = (
@@ -1505,7 +1496,6 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
-        **kwargs: Any,
     ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1561,7 +1551,6 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
         )

         outputs = self.wav2vec2(
...
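
The removals above correspond to the "Fix kwargs call" item in the commit message: the **kwargs / kwargs_call plumbing is dropped from the TF Wav2Vec2 entry points, so input_values_processing no longer raises its custom ValueError for unsupported keyword arguments. A sketch of the resulting behavior, assuming a small randomly initialized model (not part of this diff); an unknown keyword now surfaces as a plain TypeError at call time:

    import tensorflow as tf
    from transformers import TFWav2Vec2Model, Wav2Vec2Config

    # Randomly initialized model, just to demonstrate the call-time behavior.
    model = TFWav2Vec2Model(Wav2Vec2Config())

    try:
        model(tf.zeros((1, 16000)), not_a_real_argument=True)
    except TypeError as err:
        # e.g. "call() got an unexpected keyword argument 'not_a_real_argument'"
        print(err)
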
@@ -1014,6 +1014,32 @@ class TFGPT2PreTrainedModel:
         requires_backends(cls, ["tf"])


+TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFHubertForCTC:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFHubertModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["tf"])
+
+
+class TFHubertPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["tf"])
+
+
 class TFLEDForConditionalGeneration:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["tf"])
...
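
The dummy classes above mirror the public TF Hubert API so that, in an environment without TensorFlow, the names still import but fail with an informative backend error on first use. A sketch of that behavior (assuming TensorFlow is absent):

    # Without TensorFlow installed, this import resolves to the dummy class,
    # and instantiating it raises the error from requires_backends(self, ["tf"]).
    from transformers import TFHubertModel

    try:
        TFHubertModel()
    except ImportError as err:
        print(err)  # explains that TensorFlow is required
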
This diff is collapsed.
@@ -130,6 +130,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "VisualBertForQuestionAnswering",
     "VisualBertForMultipleChoice",
     "TFWav2Vec2ForCTC",
+    "TFHubertForCTC",
 ]

 # This is to make sure the transformers module imported is the one in the repo.
...