"tests/test_modeling_tf_convbert.py" did not exist on "e983da0e7d91c100e6e35efcb8a69c8cd41d6e09"
Unverified Commit 5c14fcea authored by Patrick von Platen, committed by GitHub

return attention mask in int32 (#13543)

parent 149c833b
@@ -240,12 +240,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
         attention_mask = padded_inputs.get("attention_mask")
         if attention_mask is not None:
-            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

         # Utterance-level cepstral mean and variance normalization
         if self.do_ceptral_normalize:
             attention_mask = (
-                np.array(attention_mask, dtype=np.bool)
+                np.array(attention_mask, dtype=np.int32)
                 if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
                 else None
             )
@@ -86,7 +86,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
         Every array in the list is normalized to have zero mean and unit variance
         """
         if attention_mask is not None:
-            attention_mask = np.array(attention_mask, np.bool)
+            attention_mask = np.array(attention_mask, np.int32)
             normed_input_values = []

             for vector, length in zip(input_values, attention_mask.sum(-1)):
@@ -216,7 +216,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
         # convert attention_mask to correct format
         attention_mask = padded_inputs.get("attention_mask")
         if attention_mask is not None:
-            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

         # zero-mean and unit-variance normalization
         if self.do_normalize:
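All three sites make the same change: the attention mask arrays switch from the `np.bool` alias to `np.int32`. As a minimal sketch (not part of the commit; the batch and variable names below are made up for illustration), the feature extractors now return masks shaped like this:

import numpy as np

# Hypothetical padded batch of two utterances with 3 and 5 valid frames.
lengths = [3, 5]
max_len = max(lengths)
attention_mask = [
    # 1 marks a real frame, 0 marks padding; dtype is int32 rather than bool
    np.asarray([1] * n + [0] * (max_len - n), dtype=np.int32)
    for n in lengths
]
print(attention_mask[0])          # [1 1 1 0 0]
print(attention_mask[0].dtype)    # int32
print(attention_mask[0].sum(-1))  # 3, the number of unpadded frames

An integer mask still supports the `attention_mask.sum(-1)` length computation used in the normalization code above, while avoiding the `np.bool` alias, which NumPy deprecated in version 1.20.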