You need to sign in or sign up before continuing.
Unverified Commit 8767958f authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Allow dict input for audio classification pipeline (#23445)



* Allow dict input for audio classification pipeline

* make style

* Empty commit to trigger CI

* Empty commit to trigger CI

* check for torchaudio

* add pip instructions
Co-authored-by: default avatarSylvain <sylvain.gugger@gmail.com>

* Update src/transformers/pipelines/audio_classification.py
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>

* asr -> audio class

* asr -> audio class

---------
Co-authored-by: default avatarSylvain <sylvain.gugger@gmail.com>
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
parent a6f37f88
...@@ -17,7 +17,7 @@ from typing import Union ...@@ -17,7 +17,7 @@ from typing import Union
import numpy as np import numpy as np
import requests import requests
from ..utils import add_end_docstrings, is_torch_available, logging from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
...@@ -110,12 +110,18 @@ class AudioClassificationPipeline(Pipeline): ...@@ -110,12 +110,18 @@ class AudioClassificationPipeline(Pipeline):
information. information.
Args: Args:
inputs (`np.ndarray` or `bytes` or `str`): inputs (`np.ndarray` or `bytes` or `str` or `dict`):
The input is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`) The input is either:
at the correct sampling rate (no further check will be done) or a `str` that is the filename of the - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
content of an audio file and is interpreted by *ffmpeg* in the same way. same way.
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
Raw audio at the correct sampling rate (no further check will be done)
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
pipeline do the resampling. The dict must either be in the format `{"sampling_rate": int,
"raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
`"array"` is used to denote the raw audio waveform.
top_k (`int`, *optional*, defaults to None): top_k (`int`, *optional*, defaults to None):
The number of top labels that will be returned by the pipeline. If the provided number is `None` or The number of top labels that will be returned by the pipeline. If the provided number is `None` or
higher than the number of labels available in the model configuration, it will default to the number of higher than the number of labels available in the model configuration, it will default to the number of
...@@ -151,10 +157,42 @@ class AudioClassificationPipeline(Pipeline): ...@@ -151,10 +157,42 @@ class AudioClassificationPipeline(Pipeline):
if isinstance(inputs, bytes): if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate) inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
if isinstance(inputs, dict):
# Accepting `"array"` which is the key defined in `datasets` for
# better integration
if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
raise ValueError(
"When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
'"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
"containing the sampling_rate associated with that array"
)
_inputs = inputs.pop("raw", None)
if _inputs is None:
# Remove path which will not be used from `datasets`.
inputs.pop("path", None)
_inputs = inputs.pop("array", None)
in_sampling_rate = inputs.pop("sampling_rate")
inputs = _inputs
if in_sampling_rate != self.feature_extractor.sampling_rate:
import torch
if is_torchaudio_available():
from torchaudio import functional as F
else:
raise ImportError(
"torchaudio is required to resample audio samples in AudioClassificationPipeline. "
"The torchaudio package can be installed through: `pip install torchaudio`."
)
inputs = F.resample(
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
).numpy()
if not isinstance(inputs, np.ndarray): if not isinstance(inputs, np.ndarray):
raise ValueError("We expect a numpy ndarray as input") raise ValueError("We expect a numpy ndarray as input")
if len(inputs.shape) != 1: if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
processed = self.feature_extractor( processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
......
...@@ -103,6 +103,10 @@ class AudioClassificationPipelineTests(unittest.TestCase): ...@@ -103,6 +103,10 @@ class AudioClassificationPipelineTests(unittest.TestCase):
] ]
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2]) self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
output = audio_classifier(audio_dict, top_k=4)
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
@require_torch @require_torch
@slow @slow
def test_large_model_pt(self): def test_large_model_pt(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment