Unverified commit 5c14fcea, authored by Patrick von Platen and committed by GitHub
Browse files

return attention mask in int32 (#13543)

parent 149c833b
...@@ -240,12 +240,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -240,12 +240,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
attention_mask = padded_inputs.get("attention_mask") attention_mask = padded_inputs.get("attention_mask")
if attention_mask is not None: if attention_mask is not None:
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask] padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
# Utterance-level cepstral mean and variance normalization # Utterance-level cepstral mean and variance normalization
if self.do_ceptral_normalize: if self.do_ceptral_normalize:
attention_mask = ( attention_mask = (
np.array(attention_mask, dtype=np.bool) np.array(attention_mask, dtype=np.int32)
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
else None else None
) )
......
...@@ -86,7 +86,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -86,7 +86,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
Every array in the list is normalized to have zero mean and unit variance Every array in the list is normalized to have zero mean and unit variance
""" """
if attention_mask is not None: if attention_mask is not None:
attention_mask = np.array(attention_mask, np.bool) attention_mask = np.array(attention_mask, np.int32)
normed_input_values = [] normed_input_values = []
for vector, length in zip(input_values, attention_mask.sum(-1)): for vector, length in zip(input_values, attention_mask.sum(-1)):
...@@ -216,7 +216,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -216,7 +216,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
# convert attention_mask to correct format # convert attention_mask to correct format
attention_mask = padded_inputs.get("attention_mask") attention_mask = padded_inputs.get("attention_mask")
if attention_mask is not None: if attention_mask is not None:
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask] padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
# zero-mean and unit-variance normalization # zero-mean and unit-variance normalization
if self.do_normalize: if self.do_normalize:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment