Unverified commit fbd8c51f, authored by Yoach Lacombe, committed by GitHub
Browse files

Restore casting of masked_spec_embed (#30336)

* fix Parameter dtype in audio models

* restore casting of masked_spec_embed

* restore casting of masked_spec_embed
parent 0927bfd0
...@@ -858,7 +858,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): ...@@ -858,7 +858,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -868,7 +868,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel): ...@@ -868,7 +868,7 @@ class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1005,7 +1005,7 @@ class HubertModel(HubertPreTrainedModel): ...@@ -1005,7 +1005,7 @@ class HubertModel(HubertPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1015,7 +1015,7 @@ class HubertModel(HubertPreTrainedModel): ...@@ -1015,7 +1015,7 @@ class HubertModel(HubertPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -862,7 +862,7 @@ class SEWModel(SEWPreTrainedModel): ...@@ -862,7 +862,7 @@ class SEWModel(SEWPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -872,7 +872,7 @@ class SEWModel(SEWPreTrainedModel): ...@@ -872,7 +872,7 @@ class SEWModel(SEWPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1388,7 +1388,7 @@ class SEWDModel(SEWDPreTrainedModel): ...@@ -1388,7 +1388,7 @@ class SEWDModel(SEWDPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1398,7 +1398,7 @@ class SEWDModel(SEWDPreTrainedModel): ...@@ -1398,7 +1398,7 @@ class SEWDModel(SEWDPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -616,7 +616,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module): ...@@ -616,7 +616,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -626,7 +626,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module): ...@@ -626,7 +626,7 @@ class SpeechT5SpeechEncoderPrenet(nn.Module):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1121,7 +1121,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): ...@@ -1121,7 +1121,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1131,7 +1131,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): ...@@ -1131,7 +1131,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1139,7 +1139,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): ...@@ -1139,7 +1139,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1149,7 +1149,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): ...@@ -1149,7 +1149,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1496,7 +1496,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ...@@ -1496,7 +1496,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1506,7 +1506,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ...@@ -1506,7 +1506,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1087,7 +1087,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): ...@@ -1087,7 +1087,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1097,7 +1097,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel): ...@@ -1097,7 +1097,7 @@ class Wav2Vec2BertModel(Wav2Vec2BertPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1273,7 +1273,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): ...@@ -1273,7 +1273,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1283,7 +1283,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): ...@@ -1283,7 +1283,7 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
...@@ -1158,7 +1158,7 @@ class WavLMModel(WavLMPreTrainedModel): ...@@ -1158,7 +1158,7 @@ class WavLMModel(WavLMPreTrainedModel):
if mask_time_indices is not None: if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices # apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training: elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
...@@ -1168,7 +1168,7 @@ class WavLMModel(WavLMPreTrainedModel): ...@@ -1168,7 +1168,7 @@ class WavLMModel(WavLMPreTrainedModel):
min_masks=self.config.mask_time_min_masks, min_masks=self.config.mask_time_min_masks,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis # generate indices & apply SpecAugment along feature axis
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment