Unverified Commit d94773e6 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Provide mask_time_indices to `_mask_hidden_states` to avoid double masking (#12692)



* We need to provide mask_time_indices to `_mask_hidden_states` to avoid applying the mask two times

* apply the same to wav2vec2

* Uniformize the style between hubert and wav2vec2

* fix tf as well
Co-authored-by: default avatarpatrickvonplaten <patrick.v.platen@gmail.com>
parent 144cea25
......@@ -911,11 +911,7 @@ class HubertModel(HubertPreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states = self.feature_projection(extract_features)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
hidden_states = self._mask_hidden_states(hidden_states)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
......
......@@ -1227,13 +1227,6 @@ class TFHubertMainLayer(tf.keras.layers.Layer):
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
mask_time_indices = kwargs.get("mask_time_indices", None)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
if inputs["training"]:
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
......
......@@ -1218,13 +1218,6 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
mask_time_indices = kwargs.get("mask_time_indices", None)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = tf.where(
tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
hidden_states,
)
if inputs["training"]:
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
......
......@@ -1049,11 +1049,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states, extract_features = self.feature_projection(extract_features)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
hidden_states = self._mask_hidden_states(hidden_states)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
encoder_outputs = self.encoder(
hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment