Unverified Commit 20ac86c6 authored by Ritik Nandwal's avatar Ritik Nandwal Committed by GitHub
Browse files

Add TensorFlow Wav2Vec2 for sequence classification (#22073)

* Add initial changes for TF wav2vec2 for sequence classification

* Add suggested changes

* Add serving and serving output methods

* Add serving_output implementation and fix layer_weights

* Add fixes

* Fixed test cases

* Fixing test and adding suggested changes
parent 4c2b4c4c
...@@ -197,6 +197,11 @@ Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower ...@@ -197,6 +197,11 @@ Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower
[[autodoc]] TFWav2Vec2Model [[autodoc]] TFWav2Vec2Model
- call - call
## TFWav2Vec2ForSequenceClassification
[[autodoc]] TFWav2Vec2ForSequenceClassification
- call
## TFWav2Vec2ForCTC ## TFWav2Vec2ForCTC
[[autodoc]] TFWav2Vec2ForCTC [[autodoc]] TFWav2Vec2ForCTC
......
...@@ -3443,6 +3443,7 @@ else: ...@@ -3443,6 +3443,7 @@ else:
[ [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC", "TFWav2Vec2ForCTC",
"TFWav2Vec2ForSequenceClassification",
"TFWav2Vec2Model", "TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel", "TFWav2Vec2PreTrainedModel",
] ]
...@@ -6626,6 +6627,7 @@ if TYPE_CHECKING: ...@@ -6626,6 +6627,7 @@ if TYPE_CHECKING:
from .models.wav2vec2 import ( from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC, TFWav2Vec2ForCTC,
TFWav2Vec2ForSequenceClassification,
TFWav2Vec2Model, TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel, TFWav2Vec2PreTrainedModel,
) )
......
...@@ -351,6 +351,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ...@@ -351,6 +351,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("xlnet", "TFXLNetForQuestionAnsweringSimple"), ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
] ]
) )
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[ [
...@@ -471,6 +472,9 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( ...@@ -471,6 +472,9 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
) )
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
class TFAutoModel(_BaseAutoModelClass): class TFAutoModel(_BaseAutoModelClass):
...@@ -480,6 +484,15 @@ class TFAutoModel(_BaseAutoModelClass): ...@@ -480,6 +484,15 @@ class TFAutoModel(_BaseAutoModelClass):
TFAutoModel = auto_class_update(TFAutoModel) TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
TFAutoModelForAudioClassification = auto_class_update(
TFAutoModelForAudioClassification, head_doc="audio classification"
)
class TFAutoModelForPreTraining(_BaseAutoModelClass): class TFAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
......
...@@ -59,6 +59,7 @@ else: ...@@ -59,6 +59,7 @@ else:
"TFWav2Vec2ForCTC", "TFWav2Vec2ForCTC",
"TFWav2Vec2Model", "TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel", "TFWav2Vec2PreTrainedModel",
"TFWav2Vec2ForSequenceClassification",
] ]
try: try:
...@@ -108,6 +109,7 @@ if TYPE_CHECKING: ...@@ -108,6 +109,7 @@ if TYPE_CHECKING:
from .modeling_tf_wav2vec2 import ( from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC, TFWav2Vec2ForCTC,
TFWav2Vec2ForSequenceClassification,
TFWav2Vec2Model, TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel, TFWav2Vec2PreTrainedModel,
) )
......
...@@ -22,7 +22,7 @@ import numpy as np ...@@ -22,7 +22,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
get_initializer, get_initializer,
...@@ -1212,6 +1212,46 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel): ...@@ -1212,6 +1212,46 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
return self.serving_output(output) return self.serving_output(output)
def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
return tf.math.floordiv(input_length - kernel_size, stride) + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None
):
non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
output_lengths = tf.cast(output_lengths, tf.int32)
batch_size = tf.shape(attention_mask)[0]
# check device here
attention_mask = tf.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask"
) # these two operations makes sure that all values before the output lengths idxs are attended to
## check device
attention_mask = tf.tensor_scatter_nd_update(
attention_mask,
indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1),
updates=tf.ones([batch_size], dtype=attention_mask.dtype),
)
attention_mask = tf.reverse(attention_mask, axis=[-1])
attention_mask = tf.cumsum(attention_mask, axis=-1)
attention_mask = tf.reverse(attention_mask, axis=[-1])
attention_mask = tf.cast(attention_mask, tf.bool)
return attention_mask
WAV_2_VEC_2_START_DOCSTRING = r""" WAV_2_VEC_2_START_DOCSTRING = r"""
...@@ -1552,3 +1592,125 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): ...@@ -1552,3 +1592,125 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions) return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
class TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.num_layers = config.num_hidden_layers + 1
with tf.name_scope(self._name_scope()):
if config.use_weighted_layer_sum:
self.layer_weights = self.add_weight(
shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights"
)
self.config = config
self.projector = tf.keras.layers.Dense(units=config.classifier_proj_size, name="projector")
self.classifier = tf.keras.layers.Dense(units=config.num_labels, activation=None, name="classifier")
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
not be updated during training.
"""
warnings.warn(
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5."
"Please use the equivalent `freeze_feature_encoder` method instead.",
FutureWarning,
)
self.freeze_feature_encoder()
def freeze_feature_encoder(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
not be updated during training.
"""
self.wav2vec2.feature_extractor.trainable = False
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for layer in self.wav2vec2.layers:
layer.trainable = False
@unpack_inputs
def call(
self,
input_values: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
training: bool = False,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if self.config.use_weighted_layer_sum:
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = tf.stack(hidden_states, axis=1)
norm_weights = tf.nn.softmax(self.layer_weights, axis=-1)
hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = outputs[0]
hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = tf.reduce_mean(hidden_states, axis=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
padding_mask_float = tf.cast(padding_mask, hidden_states.dtype)
hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1))
pooled_output = tf.divide(
tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1)
)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels]))
if not return_dict:
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output):
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(
logits=output.logits,
hidden_states=hidden_states,
attentions=attentions,
)
@tf.function(
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
output = self.call(input_values=inputs)
return self.serving_output(output)
...@@ -2590,6 +2590,13 @@ class TFWav2Vec2ForCTC(metaclass=DummyObject): ...@@ -2590,6 +2590,13 @@ class TFWav2Vec2ForCTC(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFWav2Vec2ForSequenceClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFWav2Vec2Model(metaclass=DummyObject): class TFWav2Vec2Model(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
......
...@@ -50,7 +50,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin ...@@ -50,7 +50,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor from transformers import (
AutoFeatureExtractor,
TFWav2Vec2ForCTC,
TFWav2Vec2ForSequenceClassification,
TFWav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
...@@ -247,6 +253,29 @@ class TFWav2Vec2ModelTester: ...@@ -247,6 +253,29 @@ class TFWav2Vec2ModelTester:
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2) self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
def check_seq_classifier_loss(self, loss, config, input_values, *args):
model = TFWav2Vec2ForSequenceClassification(config)
input_values = input_values[:3]
attention_mask = tf.ones(input_values.shape, dtype=tf.int32)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = tf.random.uniform((input_values.shape[0],), maxval=len(model.config.id2label), dtype=tf.int32)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
training = False
masked_loss = (
model(input_values, attention_mask=attention_mask, labels=labels, training=training).loss.numpy().item()
)
unmasked_loss = model(input_values, labels=labels, training=training).loss.numpy().item()
assert isinstance(masked_loss, float)
assert isinstance(unmasked_loss, float)
assert masked_loss != unmasked_loss
def check_training(self, config, input_values, *args): def check_training(self, config, input_values, *args):
model = TFWav2Vec2ForCTC(config) model = TFWav2Vec2ForCTC(config)
...@@ -286,8 +315,14 @@ class TFWav2Vec2ModelTester: ...@@ -286,8 +315,14 @@ class TFWav2Vec2ModelTester:
@require_tf @require_tf
class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else () all_model_classes = (
pipeline_model_mapping = {"feature-extraction": TFWav2Vec2Model} if is_tf_available() else {} (TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
)
pipeline_model_mapping = (
{"feature-extraction": TFWav2Vec2Model, "audio-classification": TFWav2Vec2ForSequenceClassification}
if is_tf_available()
else {}
)
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
...@@ -459,7 +494,9 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test ...@@ -459,7 +494,9 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
@require_tf @require_tf
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else () all_model_classes = (
(TFWav2Vec2Model, TFWav2Vec2ForCTC, TFWav2Vec2ForSequenceClassification) if is_tf_available() else ()
)
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
test_onnx = False test_onnx = False
...@@ -679,6 +716,11 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -679,6 +716,11 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
return [x["array"] for x in speech_samples] return [x["array"] for x in speech_samples]
def _load_superb(self, task, num_samples):
ds = load_dataset("anton-l/superb_dummy", task, split="test")
return ds[:num_samples]
def test_inference_ctc_normal(self): def test_inference_ctc_normal(self):
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True) processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
...@@ -791,3 +833,87 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -791,3 +833,87 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_librosa @require_librosa
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
def test_inference_keyword_spotting(self):
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks", from_pt=True)
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
input_data = self._load_superb("ks", 4)
inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
outputs = model(input_values, attention_mask)
predicted_logits, predicted_ids = tf.math.reduce_max(outputs.logits, axis=-1), tf.argmax(
outputs.logits, axis=-1
)
expected_labels = [7, 6, 10, 9]
expected_logits = tf.convert_to_tensor([6.1186, 11.8961, 10.2931, 6.0898])
self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_intent_classification(self):
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic", from_pt=True)
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
input_data = self._load_superb("ic", 4)
inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits_action, predicted_ids_action = tf.math.reduce_max(outputs.logits[:, :6], axis=-1), tf.argmax(
outputs.logits[:, :6], axis=-1
)
predicted_logits_object, predicted_ids_object = tf.math.reduce_max(
outputs.logits[:, 6:20], axis=-1
), tf.argmax(outputs.logits[:, 6:20], axis=-1)
predicted_logits_location, predicted_ids_location = tf.math.reduce_max(
outputs.logits[:, 20:24], axis=-1
), tf.argmax(outputs.logits[:, 20:24], axis=-1)
expected_labels_action = [0, 0, 2, 3]
expected_logits_action = tf.convert_to_tensor([0.4568, 11.0848, 1.6621, 9.3841])
expected_labels_object = [3, 10, 3, 4]
expected_logits_object = tf.convert_to_tensor([1.5322, 10.7094, 5.2469, 22.1318])
expected_labels_location = [0, 0, 0, 1]
expected_logits_location = tf.convert_to_tensor([1.5335, 6.5096, 10.5704, 11.0569])
self.assertListEqual(predicted_ids_action.numpy().tolist(), expected_labels_action)
self.assertListEqual(predicted_ids_object.numpy().tolist(), expected_labels_object)
self.assertListEqual(predicted_ids_location.numpy().tolist(), expected_labels_location)
self.assertTrue(np.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
self.assertTrue(np.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
self.assertTrue(np.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
def test_inference_speaker_identification(self):
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid", from_pt=True)
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
input_data = self._load_superb("si", 4)
output_logits = []
for example in input_data["speech"]:
input = processor(example, return_tensors="tf", padding=True)
output = model(input.input_values, attention_mask=None)
output_logits.append(output.logits[0])
output_logits = tf.stack(output_logits)
predicted_logits, predicted_ids = tf.math.reduce_max(output_logits, axis=-1), tf.argmax(output_logits, axis=-1)
expected_labels = [251, 1, 1, 3]
expected_logits = tf.convert_to_tensor([37.5627, 71.6362, 64.2419, 31.7778])
self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
def test_inference_emotion_recognition(self):
model = TFWav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er", from_pt=True)
processor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
input_data = self._load_superb("er", 4)
inputs = processor(input_data["speech"], return_tensors="tf", padding=True)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
outputs = model(input_values, attention_mask=attention_mask)
predicted_logits, predicted_ids = tf.math.reduce_max(outputs.logits, axis=-1), tf.argmax(
outputs.logits, axis=-1
)
expected_labels = [1, 1, 2, 2]
# s3prl logits for the same batch
expected_logits = tf.convert_to_tensor([2.1722, 3.0779, 8.0287, 6.6797])
self.assertListEqual(predicted_ids.numpy().tolist(), expected_labels)
self.assertTrue(np.allclose(predicted_logits, expected_logits, atol=1e-2))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment