Unverified commit d438eee0, authored by Will Rice, committed by GitHub
Browse files

Adding TFWav2Vec2Model (#11617)



* [WIP] Add TFWav2Vec2Model

Work in progress for adding a tensorflow version of Wav2Vec2

* feedback changes

* small fix

* Test Feedback Round 1

* Add SpecAugment and CTC Loss

* correct spec augment mask creation

* docstring and correct copyright

* correct bugs

* remove bogus file

* finish tests correction

* del unnecessary layers

* Update src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* correct final bug

* Feedback Changes
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1ed2ebf6
......@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -80,9 +80,22 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC
:members: forward
Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForPreTraining
:members: forward
TFWav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFWav2Vec2Model
:members: call
TFWav2Vec2ForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFWav2Vec2ForCTC
:members: call
......@@ -1430,6 +1430,14 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.wav2vec2"].extend(
[
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
)
_import_structure["models.xlm"].extend(
[
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -2743,6 +2751,12 @@ if TYPE_CHECKING:
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
from .models.xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice,
......
......@@ -37,6 +37,7 @@ from . import (
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -79,10 +80,13 @@ from . import (
TFRobertaForSequenceClassification,
TFT5ForConditionalGeneration,
TFTransfoXLLMHeadModel,
TFWav2Vec2Model,
TFXLMRobertaForMaskedLM,
TFXLMWithLMHeadModel,
TFXLNetLMHeadModel,
TransfoXLConfig,
Wav2Vec2Config,
Wav2Vec2Model,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
......@@ -287,6 +291,12 @@ MODEL_CLASSES = {
ElectraForPreTraining,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"wav2vec2": (
Wav2Vec2Config,
TFWav2Vec2Model,
Wav2Vec2Model,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
}
......
......@@ -163,6 +163,7 @@ from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLLMHeadModel,
TFTransfoXLModel,
)
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
......@@ -218,6 +219,7 @@ from .configuration_auto import (
RoFormerConfig,
T5Config,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
......@@ -263,6 +265,7 @@ TF_MODEL_MAPPING = OrderedDict(
(PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
(Wav2Vec2Config, TFWav2Vec2Model),
]
)
......
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
_import_structure = {
......@@ -38,6 +38,15 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
......@@ -54,6 +63,14 @@ if TYPE_CHECKING:
Wav2Vec2PreTrainedModel,
)
if is_tf_available():
from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
else:
import importlib
......
This diff is collapsed.
......@@ -510,14 +510,6 @@ class Wav2Vec2FeedForward(nn.Module):
return hidden_states
class Wav2Vec2Output(nn.Module):
    """Identity pass-through layer.

    Holds no parameters and applies no transformation: ``forward`` simply
    returns ``hidden_states`` unchanged and ignores ``input_tensor``.
    (This layer is vestigial — it contributes nothing to the computation.)
    """

    def __init__(self, config):
        # No submodules or weights are registered; only the Module base
        # class is initialized. ``config`` is accepted but unused.
        super().__init__()

    def forward(self, hidden_states, input_tensor):
        # ``input_tensor`` is deliberately ignored; the output equals the input.
        return hidden_states
class Wav2Vec2EncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
......
......@@ -1647,6 +1647,32 @@ class TFTransfoXLPreTrainedModel:
requires_backends(cls, ["tf"])
# Placeholder value; the real archive list is defined in the TF wav2vec2
# module and is only importable when TensorFlow is installed.
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TFWav2Vec2ForCTC:
    """Dummy stand-in for ``TFWav2Vec2ForCTC`` when TensorFlow is unavailable.

    Instantiation delegates to ``requires_backends``, which (presumably —
    it is defined elsewhere in this module) reports the missing "tf"
    backend to the user instead of failing with an opaque import error.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    # Added for consistency with the sibling dummies (TFWav2Vec2Model,
    # TFWav2Vec2PreTrainedModel), which also guard `from_pretrained`
    # behind the TensorFlow backend check.
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
class TFWav2Vec2Model:
    """Dummy stand-in for ``TFWav2Vec2Model`` when TensorFlow is unavailable.

    Both construction and ``from_pretrained`` delegate to
    ``requires_backends`` with the "tf" backend — presumably surfacing an
    informative error about the missing dependency (helper defined
    elsewhere in this module; behavior not visible here).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Guarded as a classmethod so even loading without instantiation
        # hits the backend check.
        requires_backends(cls, ["tf"])
class TFWav2Vec2PreTrainedModel:
    """Dummy stand-in for ``TFWav2Vec2PreTrainedModel`` when TensorFlow is
    unavailable.

    Mirrors the other TF dummies in this file: every entry point routes
    through ``requires_backends`` with the "tf" backend (helper defined
    elsewhere in this module; presumably raises an informative error).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Classmethod so the backend check fires even without an instance.
        requires_backends(cls, ["tf"])
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -445,6 +445,8 @@ class TFModelTesterMixin:
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
if type(key) == bool:
pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
......@@ -455,6 +457,7 @@ class TFModelTesterMixin:
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
......@@ -486,6 +489,8 @@ class TFModelTesterMixin:
if type(key) == bool:
key = np.array(key, dtype=bool)
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
......@@ -1061,7 +1066,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
input_ids = inputs_dict.get("input_ids", None)
# iterate over all generative models
for model_class in self.all_generative_model_classes:
......@@ -1097,7 +1102,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
input_ids = inputs_dict.get("input_ids", None)
for model_class in self.all_generative_model_classes:
model = model_class(config)
......
This diff is collapsed.
......@@ -126,6 +126,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"VisualBertForVisualReasoning",
"VisualBertForQuestionAnswering",
"VisualBertForMultipleChoice",
"TFWav2Vec2ForCTC",
]
# This is to make sure the transformers module imported is the one in the repo.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment