Unverified Commit 8767958f authored by Sanchit Gandhi, committed by GitHub

Allow dict input for audio classification pipeline (#23445)



* Allow dict input for audio classification pipeline

* make style

* Empty commit to trigger CI

* Empty commit to trigger CI

* check for torchaudio

* add pip instructions
Co-authored-by: Sylvain <sylvain.gugger@gmail.com>

* Update src/transformers/pipelines/audio_classification.py
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* asr -> audio class

* asr -> audio class

---------
Co-authored-by: Sylvain <sylvain.gugger@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent a6f37f88
@@ -17,7 +17,7 @@ from typing import Union
import numpy as np
import requests
-from ..utils import add_end_docstrings, is_torch_available, logging
+from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -110,12 +110,18 @@ class AudioClassificationPipeline(Pipeline):
        information.

        Args:
-            inputs (`np.ndarray` or `bytes` or `str`):
-                The inputs is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
-                at the correct sampling rate (no further check will be done) or a `str` that is the filename of the
-                audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This
-                requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the
-                content of an audio file and is interpreted by *ffmpeg* in the same way.
+            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
+                The input is either:
+                    - `str`: the filename of an audio file. The file will be read at the correct sampling rate
+                      using *ffmpeg* to get the waveform. This requires *ffmpeg* to be installed on the system.
+                    - `bytes`: the content of an audio file, interpreted by *ffmpeg* in the same way.
+                    - `np.ndarray` of shape `(n,)` and of type `np.float32` or `np.float64`: raw audio at the
+                      correct sampling rate (no further check will be done).
+                    - `dict`: raw audio sampled at an arbitrary `sampling_rate`, letting this pipeline do the
+                      resampling. The dict must be in the format `{"sampling_rate": int, "raw": np.array}` or
+                      `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or `"array"` is used
+                      to denote the raw audio waveform.
            top_k (`int`, *optional*, defaults to None):
                The number of top labels that will be returned by the pipeline. If the provided number is `None` or
                higher than the number of labels available in the model configuration, it will default to the number of
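
For reference, here is a minimal usage sketch of the new `dict` input (the checkpoint name is illustrative; any audio classification model works the same way):

```python
import numpy as np
from transformers import pipeline

# Illustrative checkpoint; any audio classification model can be substituted.
classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# Raw audio at 8 kHz; the pipeline resamples it to the feature extractor's rate.
audio = {"sampling_rate": 8000, "raw": np.zeros((8000,), dtype=np.float32)}
predictions = classifier(audio, top_k=2)
print(predictions)  # e.g. [{"score": ..., "label": ...}, ...]
```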
@@ -151,10 +157,42 @@ class AudioClassificationPipeline(Pipeline):
        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

+        if isinstance(inputs, dict):
+            # Accepting `"array"` which is the key defined in `datasets` for
+            # better integration
+            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                raise ValueError(
+                    "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
+                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    "containing the sampling_rate associated with that array"
+                )
+
+            _inputs = inputs.pop("raw", None)
+            if _inputs is None:
+                # Remove path which will not be used from `datasets`.
+                inputs.pop("path", None)
+                _inputs = inputs.pop("array", None)
+            in_sampling_rate = inputs.pop("sampling_rate")
+            inputs = _inputs
+            if in_sampling_rate != self.feature_extractor.sampling_rate:
+                import torch
+
+                if is_torchaudio_available():
+                    from torchaudio import functional as F
+                else:
+                    raise ImportError(
+                        "torchaudio is required to resample audio samples in AudioClassificationPipeline. "
+                        "The torchaudio package can be installed through: `pip install torchaudio`."
+                    )
+
+                inputs = F.resample(
+                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                ).numpy()
+
        if not isinstance(inputs, np.ndarray):
            raise ValueError("We expect a numpy ndarray as input")
        if len(inputs.shape) != 1:
-            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
+            raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")

        processed = self.feature_extractor(
            inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
......
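
As a standalone illustration of the resampling branch above, this sketch (assuming `torchaudio` is installed; the sampling rates are example values) shows how a waveform is converted between rates:

```python
import numpy as np
import torch
from torchaudio import functional as F

in_sampling_rate = 8000      # example rate of the incoming raw audio
out_sampling_rate = 16000    # example rate expected by the feature extractor

# One second of dummy single-channel audio at the input rate.
waveform = np.zeros(in_sampling_rate, dtype=np.float32)

# torchaudio resamples torch tensors, so convert from / back to numpy.
resampled = F.resample(torch.from_numpy(waveform), in_sampling_rate, out_sampling_rate).numpy()
print(resampled.shape)  # (16000,)
```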
@@ -103,6 +103,10 @@ class AudioClassificationPipelineTests(unittest.TestCase):
        ]
        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])

+        audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
+        output = audio_classifier(audio_dict, top_k=4)
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
    @require_torch
    @slow
    def test_large_model_pt(self):
......
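
Because the pipeline also accepts the `"array"` key (and ignores `"path"`), a row from a `datasets` audio dataset can be passed directly. A sketch, with an illustrative dataset and checkpoint:

```python
from datasets import load_dataset
from transformers import pipeline

classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# Illustrative dataset: each row's "audio" column is a dict with "array",
# "sampling_rate" and "path" keys, which the pipeline now understands.
ds = load_dataset("PolyAI/minds14", name="en-US", split="train")
sample = ds[0]["audio"]  # {"path": ..., "array": ..., "sampling_rate": 8000}
print(classifier(sample, top_k=2))
```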