"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "02b176c4ce14340d26d42825523f406959c6c202"
Unverified commit fb65f65e authored by Will Rice, committed by GitHub

Add TFHubertModel (#12206)

* TFHubert

* Update with TFWav2Vec Bug Fixes

* Add OOV Error

* Feedback changes

* Fix kwargs call
parent 934222e3
......@@ -355,7 +355,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Hubert | ❌ | ❌ | ✅ | ❌ | ❌ |
+| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -63,3 +63,16 @@ HubertForCTC
.. autoclass:: transformers.HubertForCTC
    :members: forward


TFHubertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFHubertModel
    :members: call


TFHubertForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFHubertForCTC
    :members: call
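A quick, hypothetical usage sketch for the new class. The checkpoint name and from_pt=True are assumptions: at the time of this commit the facebook/hubert-base-ls960 repo may only ship PyTorch weights, and Hubert reuses the Wav2Vec2 feature extractor for raw-audio preprocessing.

import numpy as np
from transformers import Wav2Vec2FeatureExtractor, TFHubertModel

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
# from_pt=True converts PyTorch weights on the fly if no TF weights are published (assumption).
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960", from_pt=True)

speech = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = feature_extractor(speech, sampling_rate=16000, return_tensors="tf")
outputs = model(inputs.input_values)
print(outputs.last_hidden_state.shape)  # (batch, frames, hidden_size)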
......@@ -1343,6 +1343,14 @@ if is_tf_available():
"TFGPT2PreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFHubertForCTC",
"TFHubertModel",
"TFHubertPreTrainedModel",
]
)
_import_structure["models.layoutlm"].extend(
[
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -2804,6 +2812,12 @@ if TYPE_CHECKING:
            TFGPT2Model,
            TFGPT2PreTrainedModel,
        )
        from .models.hubert import (
            TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFHubertForCTC,
            TFHubertModel,
            TFHubertPreTrainedModel,
        )
        from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
        from .models.longformer import (
            TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
......
......@@ -101,6 +101,7 @@ from ..funnel.modeling_tf_funnel import (
    TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..hubert.modeling_tf_hubert import TFHubertModel
from ..layoutlm.modeling_tf_layoutlm import (
    TFLayoutLMForMaskedLM,
    TFLayoutLMForSequenceClassification,
......@@ -204,6 +205,7 @@ from .configuration_auto import (
    FlaubertConfig,
    FunnelConfig,
    GPT2Config,
    HubertConfig,
    LayoutLMConfig,
    LEDConfig,
    LongformerConfig,
......@@ -266,6 +268,7 @@ TF_MODEL_MAPPING = OrderedDict(
        (BlenderbotConfig, TFBlenderbotModel),
        (BlenderbotSmallConfig, TFBlenderbotSmallModel),
        (Wav2Vec2Config, TFWav2Vec2Model),
        (HubertConfig, TFHubertModel),
    ]
)
......
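For context: TF_MODEL_MAPPING is the table the TF auto-factory consults to go from a config class to the matching TF architecture, so registering (HubertConfig, TFHubertModel) is what lets TFAutoModel resolve Hubert checkpoints. A rough sketch of the lookup, simplified from the real auto-class logic and reusing the hypothetical checkpoint from above:

from transformers import AutoConfig, TF_MODEL_MAPPING

config = AutoConfig.from_pretrained("facebook/hubert-base-ls960")  # -> HubertConfig
model_class = TF_MODEL_MAPPING[type(config)]                       # -> TFHubertModel after this change
model = model_class(config)                                        # randomly initialized TF model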
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_tf_available, is_torch_available
_import_structure = {
......@@ -33,6 +33,14 @@ if is_torch_available():
    ]

if is_tf_available():
    _import_structure["modeling_tf_hubert"] = [
        "TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFHubertForCTC",
        "TFHubertModel",
        "TFHubertPreTrainedModel",
    ]
if TYPE_CHECKING:
    from .configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
......@@ -44,6 +52,14 @@ if TYPE_CHECKING:
        HubertPreTrainedModel,
    )

    if is_tf_available():
        from .modeling_tf_hubert import (
            TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFHubertForCTC,
            TFHubertModel,
            TFHubertPreTrainedModel,
        )
else:
    import sys
......
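The _import_structure / _LazyModule pattern above keeps `import transformers` cheap: modeling_tf_hubert is only imported the first time one of its names is accessed. A minimal sketch of the idea (simplified; the real implementation is _LazyModule in transformers.file_utils):

import importlib
from types import ModuleType

class LazyModule(ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to its defining submodule,
        # e.g. "TFHubertModel" -> "modeling_tf_hubert".
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, name):
        # Import the submodule lazily, on first attribute access.
        submodule = importlib.import_module("." + self._name_to_module[name], self.__name__)
        return getattr(submodule, name)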
This diff is collapsed.
......@@ -83,13 +83,6 @@ def input_values_processing(func, config, input_values, **kwargs):
    output = {}
    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)

-    if len(kwargs["kwargs_call"]) > 0:
-        raise ValueError(
-            f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
-        )
-
-    kwargs.pop("kwargs_call")
-
    for k, v in kwargs.items():
        if isinstance(v, allowed_types) or v is None:
            output[k] = v
......@@ -1398,7 +1391,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
-        **kwargs: Any,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """
......@@ -1438,7 +1430,6 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
-            kwargs_call=kwargs,
        )

        inputs["output_hidden_states"] = (
......@@ -1505,7 +1496,6 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
-        **kwargs: Any,
    ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
......@@ -1561,7 +1551,6 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
-            kwargs_call=kwargs,
        )

        outputs = self.wav2vec2(
......
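Net effect of dropping the **kwargs / kwargs_call plumbing from TFWav2Vec2 (the new TFHubert code follows the same convention): an unsupported keyword argument now fails with Python's own TypeError instead of the hand-rolled ValueError deleted above. Roughly, with a hypothetical call:

model = TFWav2Vec2Model(config)
model(input_values, bogus_option=True)
# TypeError: call() got an unexpected keyword argument 'bogus_option'  (exact message may vary)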
......@@ -1014,6 +1014,32 @@ class TFGPT2PreTrainedModel:
requires_backends(cls, ["tf"])
TF_HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFHubertForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFHubertModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFHubertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFLEDForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
......
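The dummy classes above exist so that `from transformers import TFHubertModel` still succeeds in an environment without TensorFlow, failing only at use time with a readable backend error. A simplified sketch of what requires_backends does (the real helper lives in transformers.file_utils):

from transformers.file_utils import is_tf_available

def requires_backends(obj, backends):
    # `obj` is an instance (from __init__) or a class (from from_pretrained).
    name = getattr(obj, "__name__", obj.__class__.__name__)
    if "tf" in backends and not is_tf_available():
        raise ImportError(f"{name} requires the TensorFlow library but it was not found in your environment.")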
This diff is collapsed.
......@@ -130,6 +130,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"VisualBertForQuestionAnswering",
"VisualBertForMultipleChoice",
"TFWav2Vec2ForCTC",
"TFHubertForCTC",
]
# This is to make sure the transformers module imported is the one in the repo.
......
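Note: TFHubertForCTC lands on IGNORE_NON_AUTO_CONFIGURED presumably for the same reason as TFWav2Vec2ForCTC just above it — only TFHubertModel is registered in a TF auto mapping (TF_MODEL_MAPPING earlier in this diff), so the repository checker would otherwise flag the CTC head as missing from the auto classes.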