Unverified Commit d438eee0 authored by Will Rice's avatar Will Rice Committed by GitHub
Browse files

Adding TFWav2Vec2Model (#11617)



* [WIP] Add TFWav2Vec2Model

Work in progress for adding a tensorflow version of Wav2Vec2

* feedback changes

* small fix

* Test Feedback Round 1

* Add SpecAugment and CTC Loss

* correct spec augment mask creation

* docstring and correct copyright

* correct bugs

* remove bogus file

* finish tests correction

* del unnecessary layers

* Update src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* make style

* correct final bug

* Feedback Changes
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 1ed2ebf6
...@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ | | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ | | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ | | XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -80,9 +80,22 @@ Wav2Vec2ForCTC ...@@ -80,9 +80,22 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC .. autoclass:: transformers.Wav2Vec2ForCTC
:members: forward :members: forward
Wav2Vec2ForPreTraining Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForPreTraining .. autoclass:: transformers.Wav2Vec2ForPreTraining
:members: forward :members: forward
TFWav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFWav2Vec2Model
:members: call
TFWav2Vec2ForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFWav2Vec2ForCTC
:members: call
...@@ -1430,6 +1430,14 @@ if is_tf_available(): ...@@ -1430,6 +1430,14 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel", "TFTransfoXLPreTrainedModel",
] ]
) )
_import_structure["models.wav2vec2"].extend(
[
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
)
_import_structure["models.xlm"].extend( _import_structure["models.xlm"].extend(
[ [
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -2743,6 +2751,12 @@ if TYPE_CHECKING: ...@@ -2743,6 +2751,12 @@ if TYPE_CHECKING:
TFTransfoXLModel, TFTransfoXLModel,
TFTransfoXLPreTrainedModel, TFTransfoXLPreTrainedModel,
) )
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
from .models.xlm import ( from .models.xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
......
...@@ -37,6 +37,7 @@ from . import ( ...@@ -37,6 +37,7 @@ from . import (
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME, WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
...@@ -79,10 +80,13 @@ from . import ( ...@@ -79,10 +80,13 @@ from . import (
TFRobertaForSequenceClassification, TFRobertaForSequenceClassification,
TFT5ForConditionalGeneration, TFT5ForConditionalGeneration,
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TFWav2Vec2Model,
TFXLMRobertaForMaskedLM, TFXLMRobertaForMaskedLM,
TFXLMWithLMHeadModel, TFXLMWithLMHeadModel,
TFXLNetLMHeadModel, TFXLNetLMHeadModel,
TransfoXLConfig, TransfoXLConfig,
Wav2Vec2Config,
Wav2Vec2Model,
XLMConfig, XLMConfig,
XLMRobertaConfig, XLMRobertaConfig,
XLNetConfig, XLNetConfig,
...@@ -287,6 +291,12 @@ MODEL_CLASSES = { ...@@ -287,6 +291,12 @@ MODEL_CLASSES = {
ElectraForPreTraining, ElectraForPreTraining,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
), ),
"wav2vec2": (
Wav2Vec2Config,
TFWav2Vec2Model,
Wav2Vec2Model,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
} }
......
...@@ -163,6 +163,7 @@ from ..transfo_xl.modeling_tf_transfo_xl import ( ...@@ -163,6 +163,7 @@ from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TFTransfoXLModel, TFTransfoXLModel,
) )
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import ( from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple, TFXLMForQuestionAnsweringSimple,
...@@ -218,6 +219,7 @@ from .configuration_auto import ( ...@@ -218,6 +219,7 @@ from .configuration_auto import (
RoFormerConfig, RoFormerConfig,
T5Config, T5Config,
TransfoXLConfig, TransfoXLConfig,
Wav2Vec2Config,
XLMConfig, XLMConfig,
XLMRobertaConfig, XLMRobertaConfig,
XLNetConfig, XLNetConfig,
...@@ -263,6 +265,7 @@ TF_MODEL_MAPPING = OrderedDict( ...@@ -263,6 +265,7 @@ TF_MODEL_MAPPING = OrderedDict(
(PegasusConfig, TFPegasusModel), (PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel), (BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel), (BlenderbotSmallConfig, TFBlenderbotSmallModel),
(Wav2Vec2Config, TFWav2Vec2Model),
] ]
) )
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -38,6 +38,15 @@ if is_torch_available(): ...@@ -38,6 +38,15 @@ if is_torch_available():
] ]
if is_tf_available():
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
...@@ -54,6 +63,14 @@ if TYPE_CHECKING: ...@@ -54,6 +63,14 @@ if TYPE_CHECKING:
Wav2Vec2PreTrainedModel, Wav2Vec2PreTrainedModel,
) )
if is_tf_available():
from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
else: else:
import importlib import importlib
......
This diff is collapsed.
...@@ -510,14 +510,6 @@ class Wav2Vec2FeedForward(nn.Module): ...@@ -510,14 +510,6 @@ class Wav2Vec2FeedForward(nn.Module):
return hidden_states return hidden_states
class Wav2Vec2Output(nn.Module):
def __init__(self, config):
super().__init__()
def forward(self, hidden_states, input_tensor):
return hidden_states
class Wav2Vec2EncoderLayer(nn.Module): class Wav2Vec2EncoderLayer(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
......
...@@ -1647,6 +1647,32 @@ class TFTransfoXLPreTrainedModel: ...@@ -1647,6 +1647,32 @@ class TFTransfoXLPreTrainedModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
# Placeholder for the real archive list; None signals TF is not installed.
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TFWav2Vec2ForCTC:
    """Dummy stand-in for the TensorFlow Wav2Vec2 CTC model.

    Instantiating it raises an informative error (via ``requires_backends``)
    telling the user that TensorFlow must be installed.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
class TFWav2Vec2Model:
    """Dummy stand-in for ``TFWav2Vec2Model`` when TensorFlow is unavailable.

    Both construction and ``from_pretrained`` raise an informative error
    (via ``requires_backends``) instead of failing with an ImportError.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Mirrors the real model's alternate constructor so the error message
        # is raised at the same call sites users actually hit.
        requires_backends(cls, ["tf"])
class TFWav2Vec2PreTrainedModel:
    """Dummy stand-in for ``TFWav2Vec2PreTrainedModel`` without TensorFlow.

    Both construction and ``from_pretrained`` raise an informative error
    (via ``requires_backends``) pointing the user at the missing backend.
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same guard on the alternate constructor as on __init__.
        requires_backends(cls, ["tf"])
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -445,6 +445,8 @@ class TFModelTesterMixin: ...@@ -445,6 +445,8 @@ class TFModelTesterMixin:
for name, key in self._prepare_for_class(inputs_dict, model_class).items(): for name, key in self._prepare_for_class(inputs_dict, model_class).items():
if type(key) == bool: if type(key) == bool:
pt_inputs_dict[name] = key pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else: else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
...@@ -455,6 +457,7 @@ class TFModelTesterMixin: ...@@ -455,6 +457,7 @@ class TFModelTesterMixin:
with torch.no_grad(): with torch.no_grad():
pto = pt_model(**pt_inputs_dict) pto = pt_model(**pt_inputs_dict)
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False) tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
tf_hidden_states = tfo[0].numpy() tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy() pt_hidden_states = pto[0].numpy()
...@@ -486,6 +489,8 @@ class TFModelTesterMixin: ...@@ -486,6 +489,8 @@ class TFModelTesterMixin:
if type(key) == bool: if type(key) == bool:
key = np.array(key, dtype=bool) key = np.array(key, dtype=bool)
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else: else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch # need to rename encoder-decoder "inputs" for PyTorch
...@@ -1061,7 +1066,7 @@ class TFModelTesterMixin: ...@@ -1061,7 +1066,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self): def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict.get("input_ids", None)
# iterate over all generative models # iterate over all generative models
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
...@@ -1097,7 +1102,7 @@ class TFModelTesterMixin: ...@@ -1097,7 +1102,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_beam_search_generate(self): def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict.get("input_ids", None)
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
model = model_class(config) model = model_class(config)
......
This diff is collapsed.
...@@ -126,6 +126,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -126,6 +126,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"VisualBertForVisualReasoning", "VisualBertForVisualReasoning",
"VisualBertForQuestionAnswering", "VisualBertForQuestionAnswering",
"VisualBertForMultipleChoice", "VisualBertForMultipleChoice",
"TFWav2Vec2ForCTC",
] ]
# This is to make sure the transformers module imported is the one in the repo. # This is to make sure the transformers module imported is the one in the repo.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment