Unverified Commit 58bf8825 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Make sure tensors are always bool for mask_indices (#13977)

* correct long to bool

* up

* correct code
parent 11c043d2
...@@ -907,7 +907,7 @@ class HubertModel(HubertPreTrainedModel): ...@@ -907,7 +907,7 @@ class HubertModel(HubertPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
min_masks=2, min_masks=2,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
...@@ -917,7 +917,7 @@ class HubertModel(HubertPreTrainedModel): ...@@ -917,7 +917,7 @@ class HubertModel(HubertPreTrainedModel):
mask_prob=self.config.mask_feature_prob, mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length, mask_length=self.config.mask_feature_length,
) )
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[
:, None :, None
].expand(-1, sequence_length, -1) ].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0 hidden_states[mask_feature_indices] = 0
......
...@@ -1100,7 +1100,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ...@@ -1100,7 +1100,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
min_masks=2, min_masks=2,
) )
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training: if self.config.mask_feature_prob > 0 and self.training:
...@@ -1110,7 +1110,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): ...@@ -1110,7 +1110,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
mask_prob=self.config.mask_feature_prob, mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length, mask_length=self.config.mask_feature_length,
) )
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)[
:, None :, None
].expand(-1, sequence_length, -1) ].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0 hidden_states[mask_feature_indices] = 0
......
...@@ -738,6 +738,33 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -738,6 +738,33 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
self.assertEqual(logits.shape, (4, 1498, 32)) self.assertEqual(logits.shape, (4, 1498, 32))
def test_mask_time_feature_prob_ctc_single_batch(self):
    """Check that SpecAugment-style masking works for a batch of size one.

    With both time- and feature-axis masking enabled in train mode, a single
    6-second clip must still produce logits of the expected shape — i.e. the
    boolean mask indexing does not break on a singleton batch dimension.
    """
    # Tiny random checkpoint with time- and feature-masking switched on.
    model = Wav2Vec2ForCTC.from_pretrained(
        "hf-internal-testing/tiny-random-wav2vec2",
        mask_time_prob=0.2,
        mask_feature_prob=0.2,
        mask_time_length=2,
        mask_feature_length=2,
    )
    # Masking is only applied while training, so switch to train mode.
    model.to(torch_device).train()

    processor = Wav2Vec2Processor.from_pretrained(
        "hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
    )

    # A single clip of random audio, 6 seconds at 16 kHz.
    durations_in_seconds = [6]
    raw_audio = [np.random.random(16_000 * seconds) for seconds in durations_in_seconds]

    batch = processor(
        raw_audio, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
    )

    outputs = model(
        input_values=batch["input_values"].to(torch_device),
        attention_mask=batch["attention_mask"].to(torch_device),
    )

    # (batch, frames, vocab) — presumably 1498 frames for 6 s of audio with
    # this feature extractor's stride; matches the multi-batch test above.
    self.assertEqual(outputs.logits.shape, (1, 1498, 32))
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment