Unverified Commit d45fc7da authored by Patrick von Platen, committed by GitHub

[Speech Examples] Add pytorch speech pretraining (#13877)

* adapt wav2vec2

* add example

* add files

* adapt

* remove bogus file

* Apply suggestions from code review

* adapt files more

* upload changes

* del old files

* up

* up

* up

* up

* up

* correct gradient checkpointing

* add readme

* finish

* finish

* up

* more fixes

* up

* up

* add demo run to readme

* up
parent 3499728d
......@@ -3,6 +3,7 @@ scikit-learn
seqeval
psutil
sacrebleu >= 1.4.12
accelerate >= 0.5.0
rouge-score
tensorflow_datasets
matplotlib
......
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Speech Recognition Pre-Training
## Wav2Vec2 Speech Pre-Training
The script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py) can be used to pre-train a [Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html?highlight=wav2vec2) model from scratch.
In the script [`run_wav2vec2_pretraining_no_trainer`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py), a Wav2Vec2 model is pre-trained on audio data alone using [Wav2Vec2's contrastive loss objective](https://arxiv.org/abs/2006.11477).
The following examples show how to pre-train a `"base"`-sized Wav2Vec2 model as well as a `"large"`-sized Wav2Vec2 model using [`accelerate`](https://github.com/huggingface/accelerate).
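For reference, the contrastive objective from the [Wav2Vec2 paper](https://arxiv.org/abs/2006.11477) trains the model to identify the true quantized latent $q_t$ for a masked time step $t$ within a set $Q_t$ that contains $K$ sampled distractors:

$$\mathcal{L}_m = -\log \frac{\exp\left(\mathrm{sim}(c_t, q_t) / \kappa\right)}{\sum_{\tilde{q} \in Q_t} \exp\left(\mathrm{sim}(c_t, \tilde{q}) / \kappa\right)}$$

where $c_t$ is the context network output at step $t$, $\mathrm{sim}$ is cosine similarity, and $\kappa$ is a temperature constant. A codebook diversity loss is added to this objective during training.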
---
**NOTE 1**
Wav2Vec2's pre-training is known to be quite unstable.
It is advised to do a couple of test runs with a smaller dataset,
*e.g.* `--dataset_config_names clean clean`, `--dataset_split_names validation test`,
to find good hyper-parameters for `learning_rate`, `batch_size`, `num_warmup_steps`,
and the optimizer.
A good metric to observe during training is the gradient norm which should ideally be between 0.5 and 2.
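As a rough aid, the total gradient norm can be logged every step; a minimal sketch, assuming a plain PyTorch model (the helper below is illustrative and not part of the script):

```python
import torch

def total_grad_norm(model: torch.nn.Module) -> float:
    """L2 norm over all parameter gradients; call between backward() and optimizer.step()."""
    norms = [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]
    return torch.norm(torch.stack(norms)).item() if norms else 0.0

# inside the training loop, after accelerator.backward(loss):
# if not 0.5 <= total_grad_norm(model) <= 2.0, revisit learning_rate, batch_size, num_warmup_steps
```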
---
---
**NOTE 2**
When training a model on large datasets it is recommended to run the data preprocessing
in a first, **non-distributed** run via `--preprocessing_only`, so that
the preprocessed data can then easily be loaded on each device when the model is run
in **distributed** mode in a second step.
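Inside the script this amounts to an early exit right after the `datasets.map` preprocessing call, whose result is cached on disk; a rough sketch of the pattern (function and argument names are illustrative, not the script's exact code):

```python
def main(args):
    # ... dataset loading above ...
    vectorized_datasets = raw_datasets.map(
        prepare_dataset,  # per-sample audio feature extraction
        num_proc=args.preprocessing_num_workers,
        remove_columns=raw_datasets["train"].column_names,
    )
    if args.preprocessing_only:
        return  # the cache is populated; the distributed run now hits it on every device
    # ... model, optimizer, and training loop only run in the second, distributed step ...
```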
---
### Demo
In this demo run we pre-train a `"base"`-sized Wav2Vec2 model on just the validation
and test data of [librispeech_asr](https://huggingface.co/datasets/librispeech_asr).
The demo is run on two Titan RTX GPUs (24 GB RAM each). In case you have less RAM available
per device, consider reducing `--per_device_train_batch_size` and/or `--max_duration_in_seconds`.
```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--dataset_name="librispeech_asr" \
--dataset_config_names clean clean \
--dataset_split_names validation test \
--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
--output_dir="./wav2vec2-pretrained-demo" \
--max_train_steps="20000" \
--num_warmup_steps="32000" \
--gradient_accumulation_steps="8" \
--learning_rate="0.005" \
--weight_decay="0.01" \
--max_duration_in_seconds="20.0" \
--min_duration_in_seconds="2.0" \
--logging_steps="1" \
--saving_steps="10000" \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="8" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--adam_epsilon="1e-06" \
	--gradient_checkpointing
```
The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/wav2vec2-pretrained-demo/reports/Wav2Vec2-PreTraining-Demo-Run--VmlldzoxMDk3MjAw?accessToken=oa05s1y57lizo2ocxy3k01g6db1u4pt8m6ur2n8nl4cb0ug02ms2cw313kb8ruch).
### Base
TODO (currently running...)
### Large
To pre-train `"large-sized"` Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60),
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--dataset_name=librispeech_asr \
--dataset_config_names clean clean other \
--dataset_split_names train.100 train.360 train.500 \
--output_dir=./test \
--max_train_steps=200000 \
--num_warmup_steps=32000 \
--gradient_accumulation_steps=8 \
--learning_rate=0.001 \
--weight_decay=0.01 \
--max_duration_in_seconds=20.0 \
--min_duration_in_seconds=2.0 \
	--model_name_or_path=./ \
--logging_steps=1 \
--saving_steps=10000 \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=4 \
--adam_beta1=0.9 \
--adam_beta2=0.98 \
--adam_epsilon=1e-06 \
	--gradient_checkpointing
```
The experiment was run on 8 V100 GPUs (16 GB RAM each) for 7 days.
In case you have more than 8 GPUs available for a higher effective `batch_size`,
it is recommended to increase the `learning_rate` to `0.005` for faster convergence.
The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/pretraining-wav2vec2/reports/Wav2Vec2-Large--VmlldzoxMTAwODM4?accessToken=wm3qzcnldrwsa31tkvf2pdmilw3f63d4twtffs86ou016xjbyilh55uoi3mo1qzc) and the checkpoint pre-trained for 120,000 steps can be accessed [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps).
datasets >= 1.12.0
torch >= 1.5
torchaudio
accelerate >= 0.5.0
......@@ -23,6 +23,7 @@ from unittest.mock import patch
import torch
from transformers import Wav2Vec2ForPreTraining
from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
......@@ -41,6 +42,7 @@ SRC_DIRS = [
"image-classification",
"speech-recognition",
"audio-classification",
"speech-pretraining",
]
]
sys.path.extend(SRC_DIRS)
......@@ -59,6 +61,7 @@ if SRC_DIRS is not None:
import run_summarization
import run_swag
import run_translation
import run_wav2vec2_pretraining_no_trainer
logging.basicConfig(level=logging.DEBUG)
......@@ -447,3 +450,32 @@ class ExamplesTests(TestCasePlus):
run_audio_classification.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_wav2vec2_pretraining(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_wav2vec2_pretraining_no_trainer.py
--output_dir {tmp_dir}
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
--dataset_name patrickvonplaten/librispeech_asr_dummy
--dataset_config_names clean
--dataset_split_names validation
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--preprocessing_num_workers 16
--max_train_steps 5
--validation_split_percentage 5
--seed 42
""".split()
if is_cuda_and_apex_available():
testargs.append("--fp16")
with patch.object(sys, "argv", testargs):
run_wav2vec2_pretraining_no_trainer.main()
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
self.assertIsNotNone(model)
......@@ -48,13 +48,13 @@ def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
device: torch.device,
attention_mask: Optional[torch.tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> torch.tensor:
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
ASR <https://arxiv.org/abs/1904.08779>`__.
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
on CPU as part of the preprocessing during training.
Args:
shape: the shape for which to compute masks.
......@@ -64,7 +64,6 @@ def _compute_mask_indices(
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
"""
batch_size, sequence_length = shape
......@@ -76,42 +75,64 @@ def _compute_mask_indices(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
num_masked_spans = max(num_masked_spans, min_masks)
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
# make sure num masked indices <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
dummy_mask_idx = spec_aug_mask_idx[0]
# uniform distribution to sample from, make sure that offset samples are < sequence_length
uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
# get random indices to mask
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = (
spec_aug_mask_idxs.unsqueeze(dim=-1)
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
offsets = (
torch.arange(mask_length, device=device)[None, None, :]
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# scatter indices to mask
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
if attention_mask is not None:
# make sure padded input ids cannot be masked
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
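After this change the mask is computed with numpy on CPU and only moved to the accelerator as a torch tensor afterwards; minimal usage, mirroring the updated tests (the import path assumes the private helper in `modeling_hubert.py`):

```python
import torch
from transformers.models.hubert.modeling_hubert import _compute_mask_indices

# boolean numpy mask of shape (batch_size, sequence_length), computed on CPU
mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.5, mask_length=4, min_masks=2)
mask_time_indices = torch.from_numpy(mask).to("cuda" if torch.cuda.is_available() else "cpu")
```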
......@@ -257,6 +278,7 @@ class HubertFeatureExtractor(nn.Module):
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
def _freeze_parameters(self):
for param in self.parameters():
......@@ -264,8 +286,26 @@ class HubertFeatureExtractor(nn.Module):
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
if self.gradient_checkpointing and self.training:
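# trade compute for memory: activations of each conv layer are recomputed
# during the backward pass instead of being stored for it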
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
return hidden_states
......@@ -864,10 +904,10 @@ class HubertModel(HubertPreTrainedModel):
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
attention_mask=attention_mask,
min_masks=2,
)
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
......@@ -876,9 +916,11 @@ class HubertModel(HubertPreTrainedModel):
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[
:, None
].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0
return hidden_states
......
......@@ -586,7 +586,8 @@ class HubertUtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
......@@ -596,7 +597,8 @@ class HubertUtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
......
......@@ -40,7 +40,11 @@ if is_torch_available():
Wav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2GumbelVectorQuantizer,
_compute_mask_indices,
_sample_negative_indices,
)
class Wav2Vec2ModelTester:
......@@ -405,6 +409,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
......@@ -605,6 +615,12 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
......@@ -640,28 +656,37 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
loss = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
).loss
# more losses
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
loss_more_masked = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
).loss
# loss_more_masked has to be bigger than or equal to loss since more masked inputs have to be predicted
......@@ -727,7 +752,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
......@@ -737,7 +763,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
......@@ -753,8 +780,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
mask = torch.from_numpy(mask).to(torch_device)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
......@@ -785,8 +813,11 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
) # each value in vector consists of same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
# sample negative indices
sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
......@@ -796,15 +827,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_attn_mask(self):
def test_sample_negatives_with_mask(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
# second half of last input tensor is padded
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[-1, sequence_length // 2 :] = 0
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
mask[-1, sequence_length // 2 :] = 0
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
......@@ -812,9 +843,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# replace masked feature vectors with -100 to test that those are not sampled
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
# sample negative indices
sampled_negative_indices = _sample_negative_indices(
(batch_size, sequence_length), num_negatives, mask.cpu().numpy()
)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue((negatives >= 0).all().item())
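For intuition: `_sample_negative_indices` draws, for every time step, `num_negatives` other positions from the same sequence (restricted to valid positions when a mask is given) and offsets the indices per batch element so that the flattened `features.view(-1, hidden_size)[...]` lookup above selects vectors from the correct batch. A simplified sketch of that contract (not the library's actual implementation):

```python
import numpy as np

def sample_negative_indices_sketch(batch_size, seq_len, num_negatives, mask=None):
    """Simplified stand-in for _sample_negative_indices."""
    sampled = np.zeros((batch_size, seq_len, num_negatives), dtype=np.int64)
    for b in range(batch_size):
        valid = np.arange(seq_len) if mask is None else np.flatnonzero(mask[b])
        for t in range(seq_len):
            candidates = valid[valid != t]  # never sample the positive itself
            sampled[b, t] = np.random.choice(candidates, num_negatives)
        sampled[b] += b * seq_len  # offset into the flattened (batch * seq) dimension
    return sampled
```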
......@@ -924,16 +961,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
# Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
# correctly
@unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
def test_inference_integration(self):
return
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-base", return_attention_mask=True
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
......@@ -943,19 +975,18 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
np.random.seed(4)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
......@@ -965,14 +996,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
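# (context elided in this diff: `cosine_sim` is presumably the cosine similarity between
# outputs.projected_states and outputs.projected_quantized_states along the last dimension)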
# cosine similarity of model is all > 0.5 as model is
# pre-trained on contrastive loss
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
0.6997],
device=torch_device,
)
expected_cosine_sim_masked = torch.tensor([
0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
0.7768, 0.4898, 0.5393, 0.8183
], device=torch_device)
# fmt: on
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
......@@ -997,9 +1030,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
......@@ -1064,28 +1097,36 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
)
torch.manual_seed(0)
np.random.seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
sampled_negative_indices = _sample_negative_indices(
mask_time_indices.shape, model.config.num_negatives, mask_time_indices
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
)
# check diversity loss
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)
# check overall loss (contrastive loss + diversity loss)
expected_loss = 62.5170
expected_loss = 116.7094
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
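The diversity loss checked above follows the paper's definition: with $G$ codebook groups of $V$ entries each ($GV$ = `num_codevectors`), and the codevector perplexity being the summed exponentiated entropy of the averaged codebook distributions $\bar{p}_g$,

$$\mathcal{L}_d = \frac{GV - \sum_{g=1}^{G} \exp\left(-\sum_{v=1}^{V} \bar{p}_{g,v} \log \bar{p}_{g,v}\right)}{GV}$$

which is exactly the `(num_codevectors - outputs.codevector_perplexity) / num_codevectors` computation above.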
......