Unverified Commit 2e9fb13f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)



* fix_torch_device_generate_test

* remove @

* start adding tests

* correct wav2vec2 pretraining

* up

* up
Co-authored-by: default avatarPatrick von Platen <patrick@huggingface.co>
parent 5f2791c7
......@@ -174,11 +174,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
batch_size = batch["input_values"].shape[0]
if batch["attention_mask"] is not None:
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)
......
......@@ -172,12 +172,33 @@ class DataCollatorForWav2Vec2Pretraining:
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
batch_size = batch["input_values"].shape[0]
# make sure that no loss is computed on padded inputs
if batch["attention_mask"] is not None:
# compute real output lengths according to convolution formula
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to(
torch.long
)
attention_mask = torch.zeros(
(batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
device=batch["input_values"].device,
attention_mask=attention_mask,
min_masks=2,
)
......
......@@ -47,6 +47,7 @@ def _compute_mask_indices(
mask_prob: float,
mask_length: int,
device: torch.device,
attention_mask: Optional[torch.tensor] = None,
min_masks: int = 0,
) -> torch.tensor:
"""
......@@ -813,7 +814,10 @@ class HubertModel(HubertPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states(
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
self,
hidden_states: torch.FloatTensor,
mask_time_indices: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
......@@ -836,6 +840,7 @@ class HubertModel(HubertPreTrainedModel):
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
......@@ -847,6 +852,7 @@ class HubertModel(HubertPreTrainedModel):
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
attention_mask=attention_mask,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
......
......@@ -107,6 +107,7 @@ def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[np.ndarray] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
......@@ -166,6 +167,10 @@ def _compute_mask_indices(
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
if attention_mask is not None:
# make sure padded input ids cannot be masked
spec_aug_mask = np.where(attention_mask, spec_aug_mask, False)
return spec_aug_mask
......@@ -873,6 +878,7 @@ class FlaxWav2Vec2Module(nn.Module):
"""
extract_features = self.feature_extractor(input_values)
# make sure that no loss is computed on padded inputs
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
......
......@@ -120,6 +120,7 @@ def _compute_mask_indices(
mask_prob: float,
mask_length: int,
device: torch.device,
attention_mask: Optional[torch.tensor] = None,
min_masks: int = 0,
) -> torch.tensor:
"""
......@@ -179,6 +180,10 @@ def _compute_mask_indices(
# scatter indices to mask
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
if attention_mask is not None:
# make sure padded input ids cannot be masked
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
return spec_aug_mask
......@@ -950,7 +955,10 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
self.init_weights()
def _mask_hidden_states(
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
self,
hidden_states: torch.FloatTensor,
mask_time_indices: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
......@@ -973,6 +981,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
......@@ -984,6 +993,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
attention_mask=attention_mask,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
......@@ -1049,7 +1059,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states, extract_features = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
encoder_outputs = self.encoder(
hidden_states,
......
......@@ -245,6 +245,24 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_mask_indices_attn_mask_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_perplexity(self):
probs = np.arange(100).reshape(2, 5, 10) / 100
......
......@@ -580,6 +580,24 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_mask_indices_attn_mask_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_perplexity(self):
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment