Commit 38861045 (unverified), authored by Will Rice, committed by GitHub
Browse files

Fix TFWav2Vec2 SpecAugment (#12289)

* Fix TFWav2Vec2 SpecAugment

* Invert masks

* Feedback changes
parent bc084938
......@@ -267,7 +267,7 @@ def _compute_mask_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
)
return tf.cast(spec_aug_mask, tf.float32)
return spec_aug_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
......@@ -1139,13 +1139,12 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
return input_lengths
def _mask_hidden_states(
self, hidden_states: tf.Tensor, mask_time_indices: Optional[tf.Tensor] = None, training: bool = False
):
def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: Optional[tf.Tensor] = None):
"""
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
<https://arxiv.org/abs/1904.08779>`__ .
"""
batch_size, sequence_length, hidden_size = shape_list(hidden_states)
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
......@@ -1153,27 +1152,34 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.masked_spec_embed)
elif self.config.mask_time_prob > 0 and training:
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.shape
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
elif self.config.mask_time_prob > 0:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
min_masks=2,
)
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.masked_spec_embed)
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
# apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0 and training:
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
)
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_feature_indices, self.masked_spec_embed)
hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0)
return hidden_states
......@@ -1185,8 +1191,8 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: Optional[tf.Tensor] = None,
output_hidden_states: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs: Any,
......@@ -1220,9 +1226,14 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
mask_time_indices = kwargs.get("mask_time_indices", None)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.mask_spec_embed)
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
hidden_states = self._mask_hidden_states(hidden_states)
if inputs["training"]:
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
......@@ -1586,12 +1597,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
# when not being attended to
labels_mask = tf.cast(labels >= 0, tf.int32)
target_lengths = tf.reduce_sum(labels_mask, axis=-1)
flattened_labels = tf.boolean_mask(labels, labels_mask)
flattened_labels = tf.reshape(flattened_labels, [labels.shape[0], -1])
loss = tf.nn.ctc_loss(
logits=logits,
labels=flattened_labels,
labels=labels,
logit_length=input_lengths,
label_length=target_lengths,
blank_index=self.config.pad_token_id,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment