"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "627f44799a9f4948a6a1b8fe9e536eee0e29ea68"
Unverified Commit 38861045 authored by Will Rice's avatar Will Rice Committed by GitHub
Browse files

Fix TFWav2Vec2 SpecAugment (#12289)

* Fix TFWav2Vec2 SpecAugment

* Invert masks

* Feedback changes
parent bc084938
...@@ -267,7 +267,7 @@ def _compute_mask_indices( ...@@ -267,7 +267,7 @@ def _compute_mask_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
) )
return tf.cast(spec_aug_mask, tf.float32) return spec_aug_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
...@@ -1139,13 +1139,12 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer): ...@@ -1139,13 +1139,12 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
return input_lengths return input_lengths
def _mask_hidden_states( def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: Optional[tf.Tensor] = None):
self, hidden_states: tf.Tensor, mask_time_indices: Optional[tf.Tensor] = None, training: bool = False
):
""" """
Masks extracted features along time axis and/or along feature axis according to `SpecAugment Masks extracted features along time axis and/or along feature axis according to `SpecAugment
<https://arxiv.org/abs/1904.08779>`__ . <https://arxiv.org/abs/1904.08779>`__ .
""" """
batch_size, sequence_length, hidden_size = shape_list(hidden_states)
# `config.apply_spec_augment` can set masking to False # `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True): if not getattr(self.config, "apply_spec_augment", True):
...@@ -1153,27 +1152,34 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer): ...@@ -1153,27 +1152,34 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.masked_spec_embed) hidden_states = tf.where(
elif self.config.mask_time_prob > 0 and training: tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
# generate indices & apply SpecAugment along time axis self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
batch_size, sequence_length, hidden_size = hidden_states.shape hidden_states,
)
elif self.config.mask_time_prob > 0:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
self.config.mask_time_prob, mask_prob=self.config.mask_time_prob,
self.config.mask_time_length, mask_length=self.config.mask_time_length,
min_masks=2, min_masks=2,
) )
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.masked_spec_embed) hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
# apply SpecAugment along feature axis # apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0 and training: if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices( mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size), (batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob, mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length, mask_length=self.config.mask_feature_length,
) )
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_feature_indices, self.masked_spec_embed) hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0)
return hidden_states return hidden_states
...@@ -1185,8 +1191,8 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer): ...@@ -1185,8 +1191,8 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
position_ids: Optional[tf.Tensor] = None, position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None, head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: Optional[tf.Tensor] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[tf.Tensor] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
training: bool = False, training: bool = False,
**kwargs: Any, **kwargs: Any,
...@@ -1220,9 +1226,14 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer): ...@@ -1220,9 +1226,14 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
mask_time_indices = kwargs.get("mask_time_indices", None) mask_time_indices = kwargs.get("mask_time_indices", None)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.mask_spec_embed) hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
hidden_states = self._mask_hidden_states(hidden_states) if inputs["training"]:
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
hidden_states, hidden_states,
...@@ -1586,12 +1597,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): ...@@ -1586,12 +1597,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
# when not being attended to # when not being attended to
labels_mask = tf.cast(labels >= 0, tf.int32) labels_mask = tf.cast(labels >= 0, tf.int32)
target_lengths = tf.reduce_sum(labels_mask, axis=-1) target_lengths = tf.reduce_sum(labels_mask, axis=-1)
flattened_labels = tf.boolean_mask(labels, labels_mask)
flattened_labels = tf.reshape(flattened_labels, [labels.shape[0], -1])
loss = tf.nn.ctc_loss( loss = tf.nn.ctc_loss(
logits=logits, logits=logits,
labels=flattened_labels, labels=labels,
logit_length=input_lengths, logit_length=input_lengths,
label_length=target_lengths, label_length=target_lengths,
blank_index=self.config.pad_token_id, blank_index=self.config.pad_token_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment