Unverified Commit b4b562d8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Padded vectors should not be allowed to be sampled (#12764)

* fix_torch_device_generate_test

* remove @

* finish

* correct script

* correct script
parent 6e870100
......@@ -176,6 +176,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
batch_size = batch["input_values"].shape[0]
attention_mask = None
if batch["attention_mask"] is not None:
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
......@@ -198,6 +199,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
batch["sampled_negative_indices"] = _sample_negative_indices(
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
self.model.config.num_negatives,
attention_mask=attention_mask,
)
return batch
......
......@@ -174,7 +174,7 @@ def _compute_mask_indices(
return spec_aug_mask
def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None):
"""
Sample `num_negatives` vectors from feature vectors.
"""
......@@ -186,11 +186,13 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.random.randint(
low=0,
high=sequence_length - 1,
size=(batch_size, num_negatives * sequence_length),
)
sampled_negative_indices = []
for batch_idx in range(batch_size):
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,))
sampled_negative_indices.append(sampled_indices_slice)
sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32)
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()
......
......@@ -877,6 +877,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
return input_lengths
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
    """Turn an input-level attention mask into one over extracted feature vectors.

    The convolutional feature extractor shrinks the time axis, so the
    input-level mask is first converted into per-sequence output lengths and
    then re-expanded into a boolean mask of length ``feature_vector_length``.
    """
    # number of feature-extractor output frames for each sequence
    output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
    batch_size = attention_mask.shape[0]

    feature_mask = torch.zeros(
        (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
    )
    # place a single 1 at the last attended frame of every sequence ...
    batch_indices = torch.arange(batch_size, device=attention_mask.device)
    feature_mask[(batch_indices, output_lengths - 1)] = 1
    # ... then a reversed cumulative sum turns that marker into ones at every
    # frame up to and including it, i.e. all frames before the output length
    feature_mask = feature_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return feature_mask
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
......@@ -1044,19 +1056,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
extract_features = extract_features.transpose(1, 2)
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
attention_mask = torch.zeros(
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# compute reduced attention_mask correponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(
......@@ -1111,7 +1112,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
self.wav2vec2.feature_extractor._freeze_parameters()
@staticmethod
def _sample_negatives(features: torch.FloatTensor, num_negatives: int):
def _sample_negatives(
features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
......@@ -1125,12 +1128,15 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
with torch.no_grad():
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = torch.randint(
low=0,
high=sequence_length - 1,
size=(batch_size, num_negatives * sequence_length),
device=features.device,
sampled_negative_indices = []
for batch_idx in range(batch_size):
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
sampled_indices_slice = torch.randint(
0, high, size=(num_negatives * sequence_length,), device=features.device
)
sampled_negative_indices.append(sampled_indices_slice)
sampled_negative_indices = torch.stack(sampled_negative_indices)
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
feature_indices = (
......@@ -1263,7 +1269,14 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
if self.training:
# for training, we sample negatives
# 3. sample K negatives (distractors) quantized states for contrastive loss
negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives)
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
if attention_mask is not None:
# compute reduced attention_mask correponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
negative_quantized_features = self._sample_negatives(
quantized_features, self.config.num_negatives, attention_mask=attention_mask
)
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
......
......@@ -306,6 +306,48 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
# => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_attn_mask(self):
    """Negatives must never be sampled from padded (masked) positions."""
    batch_size = 2
    sequence_length = 10
    hidden_size = 4
    num_negatives = 3

    features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
        sequence_length, hidden_size
    )  # each value in vector consists of same value

    # second half of last input tensor is padded
    attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8)
    attention_mask[-1, sequence_length // 2 :] = 0

    # flattened (batch * time) positions that correspond to padding and must never be drawn
    forbidden_indices = (
        np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length
    ).tolist()

    features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))

    negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask)

    # make sure that no padding tokens are sampled
    self.assertTrue(all(idx not in negative_indices for idx in forbidden_indices))

    features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
    # take negative vectors from sampled indices
    sampled_negatives = features[negative_indices.reshape(-1)]
    negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
        2, 0, 1, 3
    )
    self.assertEqual(negatives.shape, (num_negatives, batch_size, sequence_length, hidden_size))

    # make sure no negatively sampled vector is actually a positive one
    for negative in negatives:
        self.assertEqual(((negative - features.reshape(negative.shape)) == 0).sum(), 0)

    # make sure that full vectors are sampled and not just slices of vectors
    # => this means that `unique()` yields a single value for `hidden_size` dim
    # NOTE: the previous `assertTrue(x, y)` form passed unconditionally because
    # the second argument was interpreted as the assertion *message*
    self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_flax
@require_datasets
......
......@@ -633,6 +633,37 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_attn_mask(self):
    """Negatives must never be sampled from padded (masked) positions."""
    batch_size = 2
    sequence_length = 10
    hidden_size = 4
    num_negatives = 3

    # second half of last input tensor is padded
    attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
    attention_mask[-1, sequence_length // 2 :] = 0

    features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
        sequence_length, hidden_size
    )  # each value in vector consists of same value
    features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

    # replace masked feature vectors with -100 to test that those are not sampled
    features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)

    negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)

    # a sampled -100 would mean a padded vector leaked into the negatives
    self.assertTrue((negatives >= 0).all().item())
    self.assertEqual(negatives.shape, (num_negatives, batch_size, sequence_length, hidden_size))

    # make sure no negatively sampled vector is actually a positive one
    for negative in negatives:
        self.assertEqual(((negative - features) == 0).sum().item(), 0)

    # make sure that full vectors are sampled and not values of vectors
    # => this means that `unique()` yields a single value for `hidden_size` dim
    # NOTE: the previous `assertTrue(x, y)` form passed unconditionally because
    # the second argument was interpreted as the assertion *message*
    self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
@require_torch
@require_datasets
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment