Unverified commit c937f0b9, authored by Sanchit Gandhi, committed by GitHub
Browse files

[Whisper] Don't return attention mask in feat extractor (#19521)

* [Whisper] Don't return attention mask in feat extractor

* remove attention mask from test

* fix failing tests

* quality
parent 83a2e694
...@@ -65,13 +65,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -65,13 +65,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
chunk_length=30, chunk_length=30,
n_fft=400, n_fft=400,
padding_value=0.0, padding_value=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
**kwargs **kwargs
): ):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.n_fft = n_fft self.n_fft = n_fft
self.hop_length = hop_length self.hop_length = hop_length
self.chunk_length = chunk_length self.chunk_length = chunk_length
self.return_attention_mask = True
self.n_samples = chunk_length * sampling_rate self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
...@@ -301,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -301,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length if max_length else self.n_samples, max_length=max_length if max_length else self.n_samples,
truncation=truncation, truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=False,
**kwargs, **kwargs,
) )
# make sure list is in array format # make sure list is in array format
......
...@@ -66,7 +66,7 @@ class WhisperFeatureExtractionTester(unittest.TestCase): ...@@ -66,7 +66,7 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
chunk_length=8, chunk_length=8,
padding_value=0.0, padding_value=0.0,
sampling_rate=4_000, sampling_rate=4_000,
return_attention_mask=True, return_attention_mask=False,
do_normalize=True, do_normalize=True,
): ):
self.parent = parent self.parent = parent
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment