"docs/source/vscode:/vscode.git/clone" did not exist on "baf4bacb1f10ecb63f0efc98d07463ae8799c7e3"
Unverified Commit a57d784d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2] Fix dtype 64 bug (#13517)

* fix

* 2nd fix
parent 72ec2f3e
......@@ -210,7 +210,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)
# always return batch
......
......@@ -207,10 +207,10 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
elif (
not isinstance(input_values, np.ndarray)
and isinstance(input_values[0], np.ndarray)
and input_values[0].dtype is np.float64
and input_values[0].dtype is np.dtype(np.float64)
):
padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64):
padded_inputs["input_values"] = input_values.astype(np.float32)
# convert attention_mask to correct format
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment