Unverified Commit 48463ebb authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add Speaker Diarization and Verification heads (#14723)

* Models

* Squashed commit of the following:

commit 72278e1e931a16d0879acc77f65762f3364833d0
Author: anton-l <aglozhkov@gmail.com>
Date:   Fri Dec 10 21:45:08 2021 +0300

* Add unispeech heads

* Add sd/sv automodels

* Docs cleanup

* Fix docstrings

* rename xvector classes

* examples

* Tests cleanup

* Style

* Better checkpoints for tests

* leftover docs

* apply review suggestions

* Style + init tests

* Update unispeech-sat tdnn downsampling
parent 2e07180c
......@@ -33,9 +33,11 @@ if is_torch_available():
import torch
from transformers import (
UniSpeechSatForAudioFrameClassification,
UniSpeechSatForCTC,
UniSpeechSatForPreTraining,
UniSpeechSatForSequenceClassification,
UniSpeechSatForXVector,
UniSpeechSatModel,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
......@@ -70,6 +72,10 @@ class UniSpeechSatModelTester:
mask_time_length=2,
vocab_size=32,
do_stable_layer_norm=False,
tdnn_dim=(32, 32),
tdnn_kernel=(3, 3),
tdnn_dilation=(1, 1),
xvector_output_dim=32,
scope=None,
):
self.parent = parent
......@@ -97,6 +103,10 @@ class UniSpeechSatModelTester:
self.do_stable_layer_norm = do_stable_layer_norm
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.tdnn_dim = tdnn_dim
self.tdnn_kernel = tdnn_kernel
self.tdnn_dilation = tdnn_dilation
self.xvector_output_dim = xvector_output_dim
self.scope = scope
output_seq_length = self.seq_length
......@@ -135,6 +145,10 @@ class UniSpeechSatModelTester:
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
tdnn_dim=self.tdnn_dim,
tdnn_kernel=self.tdnn_kernel,
tdnn_dilation=self.tdnn_dilation,
xvector_output_dim=self.xvector_output_dim,
)
def create_and_check_model(self, config, input_values, attention_mask):
......@@ -277,6 +291,30 @@ class UniSpeechSatModelTester:
loss.backward()
def check_xvector_training(self, config, *args):
config.ctc_zero_infinity = True
model = UniSpeechSatForXVector(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
# use a longer sequence length to account for TDNN temporal downsampling
input_values = floats_tensor([self.batch_size, self.seq_length * 2], self.vocab_size)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = UniSpeechSatForCTC(config)
model.to(torch_device)
......@@ -300,7 +338,14 @@ class UniSpeechSatModelTester:
@require_torch
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
(
UniSpeechSatForCTC,
UniSpeechSatForPreTraining,
UniSpeechSatModel,
UniSpeechSatForSequenceClassification,
UniSpeechSatForAudioFrameClassification,
UniSpeechSatForXVector,
)
if is_torch_available()
else ()
)
......@@ -335,6 +380,10 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_xvector_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_xvector_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
......@@ -417,6 +466,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"label_embeddings_concat",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
......@@ -623,6 +673,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"label_embeddings_concat",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
......@@ -811,3 +862,56 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
# fmt: on
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
def test_inference_diarization(self):
model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to(
torch_device
)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
input_data = self._load_superb("sd", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
# labels is a one-hot array of shape (num_frames, num_speakers)
labels = (outputs.logits > 0).long()
# s3prl logits for the same batch
expected_logits = torch.tensor(
[
[[-5.6119, -5.5845], [-3.7772, -5.4824], [-3.6914, -5.1619], [-4.7560, -5.0496]],
[[-6.3785, -4.8365], [-5.5863, -5.4149], [-5.5639, -4.8469], [-6.1511, -4.0052]],
[[-6.0355, -3.7414], [-5.5968, -4.8061], [-5.4620, -4.7310], [-5.5864, -4.6078]],
[[-5.9493, -4.8963], [-4.4050, -5.4476], [-4.1755, -5.1395], [-4.0272, -4.3705]],
],
device=torch_device,
)
self.assertEqual(labels[0, :, 0].sum(), 270)
self.assertEqual(labels[0, :, 1].sum(), 647)
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
def test_inference_speaker_verification(self):
model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
input_data = self._load_superb("si", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
labels = torch.tensor([5, 1, 1, 3], device=torch_device).T
with torch.no_grad():
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
outputs = model(input_values, attention_mask=attention_mask, labels=labels)
embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
# id10002 vs id10002
self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9671, 3)
# id10006 vs id10002
self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.4941, 3)
# id10002 vs id10004
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.5616, 3)
self.assertAlmostEqual(outputs.loss.item(), 18.5925, 3)
......@@ -44,10 +44,12 @@ if is_torch_available():
from transformers import (
Wav2Vec2FeatureExtractor,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForXVector,
Wav2Vec2Model,
Wav2Vec2Processor,
)
......@@ -96,6 +98,10 @@ class Wav2Vec2ModelTester:
do_stable_layer_norm=False,
num_adapter_layers=1,
adapter_stride=2,
tdnn_dim=(32, 32),
tdnn_kernel=(5, 3),
tdnn_dilation=(1, 2),
xvector_output_dim=32,
scope=None,
):
self.parent = parent
......@@ -126,6 +132,10 @@ class Wav2Vec2ModelTester:
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.scope = scope
self.tdnn_dim = tdnn_dim
self.tdnn_kernel = tdnn_kernel
self.tdnn_dilation = tdnn_dilation
self.xvector_output_dim = xvector_output_dim
output_seq_length = self.seq_length
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
......@@ -168,6 +178,10 @@ class Wav2Vec2ModelTester:
vocab_size=self.vocab_size,
num_adapter_layers=self.num_adapter_layers,
adapter_stride=self.adapter_stride,
tdnn_dim=self.tdnn_dim,
tdnn_kernel=self.tdnn_kernel,
tdnn_dilation=self.tdnn_dilation,
xvector_output_dim=self.xvector_output_dim,
)
def create_and_check_model(self, config, input_values, attention_mask):
......@@ -332,6 +346,29 @@ class Wav2Vec2ModelTester:
loss.backward()
def check_xvector_training(self, config, input_values, *args):
config.ctc_zero_infinity = True
model = Wav2Vec2ForXVector(config=config)
model.to(torch_device)
model.train()
# freeze everything but the classification head
model.freeze_base_model()
input_values = input_values[:3]
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
loss = model(input_values, labels=labels).loss
self.parent.assertFalse(torch.isinf(loss).item())
loss.backward()
def check_labels_out_of_vocab(self, config, input_values, *args):
model = Wav2Vec2ForCTC(config)
model.to(torch_device)
......@@ -398,6 +435,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_xvector_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_xvector_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
......@@ -489,6 +530,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
......@@ -573,7 +615,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
(
Wav2Vec2ForCTC,
Wav2Vec2Model,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForSequenceClassification,
Wav2Vec2ForPreTraining,
Wav2Vec2ForAudioFrameClassification,
Wav2Vec2ForXVector,
)
if is_torch_available()
else ()
)
......@@ -622,6 +672,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)
def test_xvector_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_xvector_training(*config_and_inputs)
def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
......@@ -703,6 +757,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
"objective.weight",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
......@@ -1369,3 +1424,54 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
transcription = processor.batch_decode(logits.cpu().numpy()).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd")
input_data = self._load_superb("sd", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
outputs = model(input_values, attention_mask=attention_mask)
# labels is a one-hot array of shape (num_frames, num_speakers)
labels = (outputs.logits > 0).long()
# s3prl logits for the same batch
expected_logits = torch.tensor(
[
[[-5.2807, -5.1272], [-5.4059, -4.7757], [-5.2764, -4.9621], [-5.0117, -4.5851]],
[[-1.7643, -0.5462], [-1.7369, -0.2649], [-1.5066, -0.6200], [-4.5703, -2.4863]],
[[-0.8656, -0.4783], [-0.8899, -0.3289], [-0.9267, -0.5781], [-0.7817, -0.4619]],
[[-4.8625, -2.5316], [-5.2339, -2.2155], [-4.9835, -2.0344], [-4.4727, -1.8421]],
],
device=torch_device,
)
self.assertEqual(labels[0, :, 0].sum(), 555)
self.assertEqual(labels[0, :, 1].sum(), 299)
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
def test_inference_speaker_verification(self):
model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device)
processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sv")
input_data = self._load_superb("si", 4)
inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
labels = torch.tensor([5, 1, 1, 3], device=torch_device).T
with torch.no_grad():
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
outputs = model(input_values, attention_mask=attention_mask, labels=labels)
embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1).cpu()
cosine_sim = torch.nn.CosineSimilarity(dim=-1)
# id10002 vs id10002
self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).numpy(), 0.9758, 3)
# id10006 vs id10002
self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).numpy(), 0.7579, 3)
# id10002 vs id10004
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).numpy(), 0.7594, 3)
self.assertAlmostEqual(outputs.loss.item(), 17.7963, 3)
......@@ -74,6 +74,12 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
"AutoModelForNextSentencePrediction",
),
(
"audio-frame-classification",
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
"AutoModelForAudioFrameClassification",
),
("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment