Unverified Commit d7b3b709 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Fix normalization for non-padded tensors (#13512)

* finalize

* Apply suggestions from code review

* finish cleaner implementation

* more tests

* small fix

* finish

* up
parent c63fcabf
...@@ -341,7 +341,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin): ...@@ -341,7 +341,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
return processed_features return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs): def _get_padding_strategies(self, padding=False, max_length=None):
""" """
Find the correct padding strategy Find the correct padding strategy
""" """
......
...@@ -93,10 +93,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -93,10 +93,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
@staticmethod @staticmethod
def utterance_cmvn( def utterance_cmvn(
x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True x: np.ndarray,
input_length: int,
normalize_means: Optional[bool] = True,
normalize_vars: Optional[bool] = True,
padding_value: float = 0.0,
) -> np.ndarray: ) -> np.ndarray:
# make sure we normalie float32 arrays # make sure we normalie float32 arrays
mean = x[:input_length].mean(axis=0) mean = x[:input_length].mean(axis=0)
square_sums = (x[:input_length] ** 2).sum(axis=0) square_sums = (x[:input_length] ** 2).sum(axis=0)
...@@ -107,15 +110,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -107,15 +110,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
std = np.sqrt(np.maximum(var, 1e-10)) std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std) x = np.divide(x, std)
if x.shape[0] > input_length:
x[input_length:] = padding_value
# make sure array is in float32 # make sure array is in float32
x = x.astype(np.float32) x = x.astype(np.float32)
return x return x
def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]: def normalize(
self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
) -> List[np.ndarray]:
lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
return [ return [
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars) self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value)
for x, n in zip(input_values, input_lengths) for x, n in zip(input_features, lengths)
] ]
def __call__( def __call__(
...@@ -197,7 +206,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -197,7 +206,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
) )
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray): if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech) for speech in raw_speech] raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray): elif not is_batched and not isinstance(raw_speech, np.ndarray):
...@@ -225,21 +233,25 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -225,21 +233,25 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
**kwargs, **kwargs,
) )
if "attention_mask" in padded_inputs: # make sure list is in array format
input_lengths = padded_inputs["attention_mask"].sum(-1) input_features = padded_inputs.get("input_features")
else: if isinstance(input_features[0], list):
padded_input_values = padded_inputs["input_features"] padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
attention_mask = padded_inputs.get("attention_mask")
if attention_mask is not None:
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
# Utterance-level cepstral mean and variance normalization # Utterance-level cepstral mean and variance normalization
if self.do_ceptral_normalize: if self.do_ceptral_normalize:
input_features = padded_inputs["input_features"] attention_mask = (
np.array(attention_mask, dtype=np.bool)
# make sure list is in array format if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
if isinstance(input_features[0], list): else None
input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features] )
padded_inputs["input_features"] = self.normalize(
padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths) padded_inputs["input_features"], attention_mask=attention_mask
)
if return_tensors is not None: if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
......
...@@ -79,13 +79,25 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -79,13 +79,25 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
self.do_normalize = do_normalize self.do_normalize = do_normalize
@staticmethod @staticmethod
def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]: def zero_mean_unit_var_norm(
input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
) -> List[np.ndarray]:
""" """
Every array in the list is normalized to have zero mean and unit variance Every array in the list is normalized to have zero mean and unit variance
""" """
normed_input_values = [ if attention_mask is not None:
(x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths) attention_mask = np.array(attention_mask, np.bool)
] normed_input_values = []
for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length > normed_slice.shape[0]:
normed_slice[length:] = padding_value
normed_input_values.append(normed_slice)
else:
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
return normed_input_values return normed_input_values
def __call__( def __call__(
...@@ -172,14 +184,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -172,14 +184,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
) )
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
raw_speech = raw_speech.astype(np.float32)
# always return batch # always return batch
if not is_batched: if not is_batched:
raw_speech = [raw_speech] raw_speech = [raw_speech]
...@@ -196,19 +200,33 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -196,19 +200,33 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
return_attention_mask=return_attention_mask, return_attention_mask=return_attention_mask,
) )
if "attention_mask" in padded_inputs: # convert input values to correct format
input_lengths = padded_inputs["attention_mask"].sum(-1) input_values = padded_inputs["input_values"]
else: if not isinstance(input_values[0], np.ndarray):
padded_input_values = padded_inputs["input_values"] padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])] elif (
not isinstance(input_values, np.ndarray)
if isinstance(padded_inputs["input_values"][0], np.ndarray): and isinstance(input_values[0], np.ndarray)
padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]] and input_values[0].dtype is np.float64
):
padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
padded_inputs["input_values"] = input_values.astype(np.float32)
# convert attention_mask to correct format
attention_mask = padded_inputs.get("attention_mask")
if attention_mask is not None:
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
# zero-mean and unit-variance normalization # zero-mean and unit-variance normalization
if self.do_normalize: if self.do_normalize:
attention_mask = (
attention_mask
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
else None
)
padded_inputs["input_values"] = self.zero_mean_unit_var_norm( padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
padded_inputs["input_values"], input_lengths=input_lengths padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
) )
if return_tensors is not None: if return_tensors is not None:
......
...@@ -136,18 +136,49 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt ...@@ -136,18 +136,49 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
def test_cepstral_mean_and_variance_normalization(self): def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
def _check_zero_mean_unit_variance(input_vector): paddings = ["longest", "max_length", "do_not_pad"]
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3)) max_lengths = [None, 16, None]
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3)) var_tolerances = [1e-3, 1e-3, 1e-1]
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]]) inputs = feature_extractor(
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]]) speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]]) )
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
def test_cepstral_mean_and_variance_normalization_np(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
paddings = ["longest", "max_length", "do_not_pad"]
max_lengths = [None, 16, None]
var_tolerances = [1e-3, 1e-3, 1e-1]
for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
inputs = feature_extractor(
speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = [np.sum(x) for x in attention_mask]
def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
_check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
_check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
_check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
def test_cepstral_mean_and_variance_normalization_trunc(self): def test_cepstral_mean_and_variance_normalization_trunc(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
......
...@@ -120,21 +120,45 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest ...@@ -120,21 +120,45 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_zero_mean_unit_variance_normalization(self): def test_zero_mean_unit_variance_normalization_np(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector): paddings = ["longest", "max_length", "do_not_pad"]
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3) max_lengths = [None, 1600, None]
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3) for max_length, padding in zip(max_lengths, paddings):
processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
input_values = processed.input_values
_check_zero_mean_unit_variance(input_values[0, :800]) def _check_zero_mean_unit_variance(input_vector):
_check_zero_mean_unit_variance(input_values[1, :1000]) self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
_check_zero_mean_unit_variance(input_values[2]) self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0][:800])
_check_zero_mean_unit_variance(input_values[1][:1000])
_check_zero_mean_unit_variance(input_values[2][:1200])
def test_zero_mean_unit_variance_normalization(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
lengths = range(800, 1400, 200)
speech_inputs = [floats_list((1, x))[0] for x in lengths]
paddings = ["longest", "max_length", "do_not_pad"]
max_lengths = [None, 1600, None]
for max_length, padding in zip(max_lengths, paddings):
processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0][:800])
_check_zero_mean_unit_variance(input_values[1][:1000])
_check_zero_mean_unit_variance(input_values[2][:1200])
def test_zero_mean_unit_variance_normalization_trunc(self): def test_zero_mean_unit_variance_normalization_trunc_np(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract( processed = feat_extract(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment