Unverified Commit 3d574044 authored by LWprogramming's avatar LWprogramming Committed by GitHub
Browse files

is_batched fix for remaining 2-D numpy arrays (#23309)

* Fix is_batched code to allow 2-D numpy arrays for audio

* Tests

* Fix typo

* Incorporate comments from PR #23223
parent 6b7d6f84
......@@ -135,7 +135,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
sampling_rate (`int`, *optional*):
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
......@@ -160,9 +161,11 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
......
......@@ -272,7 +272,8 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
truncation (`str`, *optional*):
Truncation pattern for long audio inputs. Two patterns are available:
- `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
......@@ -312,9 +313,11 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
......
......@@ -180,7 +180,8 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
Args:
raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
of float values, a list of tensors, a list of numpy arrays or a list of list of float values.
of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be
mono channel audio, not stereo, i.e. single float per timestep.
padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
......@@ -231,9 +232,11 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
......
......@@ -141,7 +141,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
......@@ -200,9 +201,11 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
......
......@@ -201,7 +201,8 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
Args:
audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*):
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. This outputs waveform features.
values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must
be mono channel audio, not stereo, i.e. single float per timestep.
audio_target (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*):
The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a
list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel
......@@ -307,9 +308,11 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
is_batched = bool(
isinstance(speech, (list, tuple))
and (isinstance(speech[0], np.ndarray) or isinstance(speech[0], (tuple, list)))
is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1
if is_batched_numpy and len(speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
......
......@@ -129,7 +129,8 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'pt'`: Return PyTorch `torch.Tensor` objects.
......@@ -176,9 +177,11 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
......
......@@ -152,7 +152,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
truncation (`bool`, *optional*, default to `True`):
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
pad_to_multiple_of (`int`, *optional*, defaults to None):
......@@ -203,9 +204,11 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
......
......@@ -125,6 +125,14 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
@require_torch
def test_double_precision_pad(self):
import torch
......
......@@ -139,6 +139,14 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_double_precision_pad(self):
import torch
......
......@@ -134,6 +134,14 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
......
......@@ -136,6 +136,14 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
......
......@@ -275,6 +275,14 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_values
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_batch_feature_target(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
......
......@@ -189,6 +189,15 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
self.assertTrue(encoded_audios.shape[-2] <= feature_extractor.spectrogram_length)
self.assertTrue(encoded_audios.shape[-3] == feature_extractor.num_channels)
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_audios = feature_extractor(np_speech_inputs, return_tensors="np", sampling_rate=44100).audio_values
self.assertTrue(encoded_audios.ndim == 4)
self.assertTrue(encoded_audios.shape[-1] == feature_extractor.feature_size)
self.assertTrue(encoded_audios.shape[-2] <= feature_extractor.spectrogram_length)
self.assertTrue(encoded_audios.shape[-3] == feature_extractor.num_channels)
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
......
......@@ -173,6 +173,14 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test truncation required
speech_inputs = [floats_list((1, x))[0] for x in range(200, (feature_extractor.n_samples + 500), 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment