Unverified Commit 11edecd7 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Fix uninitialized variables when `config.mask_feature_prob > 0` (#12705)

parent f9ac677e
...@@ -811,6 +811,7 @@ class HubertModel(HubertPreTrainedModel): ...@@ -811,6 +811,7 @@ class HubertModel(HubertPreTrainedModel):
self.init_weights() self.init_weights()
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_hidden_states( def _mask_hidden_states(
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
): ):
...@@ -823,13 +824,13 @@ class HubertModel(HubertPreTrainedModel): ...@@ -823,13 +824,13 @@ class HubertModel(HubertPreTrainedModel):
if not getattr(self.config, "apply_spec_augment", True): if not getattr(self.config, "apply_spec_augment", True):
return hidden_states return hidden_states
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
mask_prob=self.config.mask_time_prob, mask_prob=self.config.mask_time_prob,
......
...@@ -961,13 +961,13 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ...@@ -961,13 +961,13 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
if not getattr(self.config, "apply_spec_augment", True): if not getattr(self.config, "apply_spec_augment", True):
return hidden_states return hidden_states
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.size()
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
mask_prob=self.config.mask_time_prob, mask_prob=self.config.mask_time_prob,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment