"test/srt/git@developer.sourcefind.cn:change/sglang.git" did not exist on "7c3f07dbcba5fb36b889ab219a758663f111e599"
Unverified Commit 5de2a6d5 authored by LWprogramming's avatar LWprogramming Committed by GitHub
Browse files

Fix wav2vec2 is_batched check to include 2-D numpy arrays (#23223)



* Fix wav2vec2 is_batched check to include 2-D numpy arrays

* address comment

* Add tests

* oops

* oops

* Switch to np array
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Switch to np array

* condition merge

* Specify mono channel only in comment

* oops, add other comment too

* make style

* Switch list check from falsiness to empty

---------
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 4ddd9de9
...@@ -140,7 +140,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin): ...@@ -140,7 +140,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
return_attention_mask if return_attention_mask is not None else self.return_attention_mask return_attention_mask if return_attention_mask is not None else self.return_attention_mask
) )
if not required_input: if len(required_input) == 0:
if return_attention_mask: if return_attention_mask:
processed_features["attention_mask"] = [] processed_features["attention_mask"] = []
return processed_features return processed_features
......
...@@ -117,7 +117,8 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -117,7 +117,8 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
Args: Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among: index) among:
...@@ -181,9 +182,11 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): ...@@ -181,9 +182,11 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
"Failing to do so can result in silent errors that might be hard to debug." "Failing to do so can result in silent errors that might be hard to debug."
) )
is_batched = bool( is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
isinstance(raw_speech, (list, tuple)) if is_batched_numpy and len(raw_speech.shape) > 2:
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
) )
# always return batch # always return batch
......
...@@ -817,12 +817,15 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer): ...@@ -817,12 +817,15 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
Args: Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrayr or a list of list of float values. values, a list of numpy array or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
""" """
is_batched = bool( is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
isinstance(raw_speech, (list, tuple)) if is_batched_numpy and len(raw_speech.shape) > 2:
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
) )
# make sure input is in list format # make sure input is in list format
......
...@@ -123,6 +123,14 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest ...@@ -123,6 +123,14 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_zero_mean_unit_variance_normalization_np(self): def test_zero_mean_unit_variance_normalization_np(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
......
...@@ -164,6 +164,14 @@ class Wav2Vec2TokenizerTest(unittest.TestCase): ...@@ -164,6 +164,14 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test 2-D numpy arrays are batched.
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
np_speech_inputs = np.asarray(speech_inputs)
encoded_sequences_1 = tokenizer(speech_inputs, return_tensors="np").input_values
encoded_sequences_2 = tokenizer(np_speech_inputs, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_padding(self, max_length=50): def test_padding(self, max_length=50):
def _input_values_have_equal_length(input_values): def _input_values_have_equal_length(input_values):
length = len(input_values[0]) length = len(input_values[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment