Unverified commit 003a7cc6 authored by bofeng huang, committed by GitHub

[Whisper] Fix feature normalization in `WhisperFeatureExtractor` (#21938)

Fix feature normalization in WhisperFeatureExtractor
parent 718e9d77
@@ -334,14 +334,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
             max_length=max_length if max_length else self.n_samples,
             truncation=truncation,
             pad_to_multiple_of=pad_to_multiple_of,
-            return_attention_mask=return_attention_mask,
+            return_attention_mask=return_attention_mask or do_normalize,
         )
-        # make sure list is in array format
-        input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
-
-        if return_attention_mask:
-            # rescale from sample (48000) to feature (3000)
-            padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
 
         # zero-mean and unit-variance normalization
         if do_normalize:
@@ -350,6 +344,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
                 attention_mask=padded_inputs["attention_mask"],
                 padding_value=self.padding_value,
             )
+            padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
+
+        # make sure list is in array format
+        input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
 
         input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
...
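For context: before this change, `input_features` was transposed out of `padded_inputs` before the `do_normalize` branch ran, so the values normalized by `zero_mean_unit_var_norm` were never the ones fed to `_np_extract_fbank_features`, and the attention mask the normalizer needs could be missing when the caller did not set `return_attention_mask`. Below is a minimal sketch of mask-aware zero-mean, unit-variance normalization in the spirit of the `zero_mean_unit_var_norm` helper called in the hunk above, assuming 1-D float waveforms and a 0/1 sample-level mask; it illustrates the technique rather than reproducing the upstream implementation.

```python
import numpy as np

def zero_mean_unit_var_norm_sketch(input_values, attention_mask=None, padding_value=0.0):
    """Normalize each waveform to zero mean and unit variance, computing the
    statistics only over the non-padded samples indicated by attention_mask."""
    if attention_mask is not None:
        attention_mask = np.asarray(attention_mask, dtype=np.int32)
        normed = []
        for vector, length in zip(input_values, attention_mask.sum(-1)):
            # Statistics come from the real samples only, not the padding.
            normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
            # Keep the padded tail pinned at the padding value.
            if length < normed_slice.shape[0]:
                normed_slice[length:] = padding_value
            normed.append(normed_slice)
        return normed
    # Without a mask, fall back to normalizing over the full vectors.
    return [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
```

Because the statistics must exclude padding, the padder has to return the attention mask whenever normalization is requested, which is exactly what `return_attention_mask=return_attention_mask or do_normalize` guarantees.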
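A usage sketch of the fixed path (the checkpoint name and dummy waveform are illustrative): with `do_normalize=True`, the raw audio is now normalized before log-mel extraction even if the caller never asks for the attention mask.

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# Dummy 5-second mono waveform at 16 kHz; substitute real audio here.
waveform = np.random.randn(16000 * 5).astype(np.float32)

# do_normalize=True zero-means the raw audio (using the attention mask for
# padding) before the log-mel features are computed; after this fix the
# normalized values are actually the ones passed on to feature extraction.
inputs = feature_extractor(waveform, sampling_rate=16000, do_normalize=True)
print(np.asarray(inputs["input_features"]).shape)  # e.g. (1, 80, 3000)
```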