audio_classification.py 8.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
15
from typing import Union
16
17

import numpy as np
18
19
import requests

20
from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from .base import PIPELINE_INIT_ARGS, Pipeline


# Only import the audio-classification model mapping when torch is available, so that importing this module does not
# fail in an environment without torch.
if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)


def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.ndarray:
    """
    Helper function to read an audio file through ffmpeg.

    The payload is piped to an `ffmpeg` subprocess which converts it to a single channel (`-ac 1`) of raw
    little-endian 32-bit floats (`-f f32le`) at `sampling_rate` Hz (`-ar`).

    Args:
        bpayload (`bytes`):
            The content of the audio file, sent to ffmpeg on its standard input.
        sampling_rate (`int`):
            The sampling rate ffmpeg is asked to produce.

    Returns:
        `np.ndarray`: A one-dimensional `np.float32` array built from the bytes ffmpeg wrote on its standard output.

    Raises:
        `ValueError`: If the `ffmpeg` executable cannot be found, or if ffmpeg produced no audio samples.
    """
    ffmpeg_command = [
        "ffmpeg",
        "-i",
        "pipe:0",
        "-ac",
        "1",
        "-ar",
        f"{sampling_rate}",
        "-f",
        "f32le",
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]

    try:
        ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    except FileNotFoundError as error:
        # Keep the historical exception type, but chain the original error so the root cause stays visible.
        raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error

    # The context manager guarantees the pipes are closed and the process is waited on, even if `communicate` raises.
    with ffmpeg_process:
        out_bytes, _ = ffmpeg_process.communicate(bpayload)

    audio = np.frombuffer(out_bytes, np.float32)
    # ffmpeg runs with `-loglevel quiet`, so an empty output is the only failure signal checked here.
    if audio.shape[0] == 0:
        raise ValueError("Malformed soundfile")
    return audio


@add_end_docstrings(PIPELINE_INIT_ARGS)
class AudioClassificationPipeline(Pipeline):
    """
    Audio classification pipeline using any `AutoModelForAudioClassification`. This pipeline predicts the class of a
    raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio
    formats.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
    >>> classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
    [{'score': 0.997, 'label': '_unknown_'}, {'score': 0.002, 'label': 'left'}, {'score': 0.0, 'label': 'yes'}, {'score': 0.0, 'label': 'down'}, {'score': 0.0, 'label': 'stop'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"audio-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=audio-classification).
    """

    def __init__(self, *args, **kwargs):
        # Default to the 5 best labels, but keep a `top_k` explicitly supplied by the caller instead of silently
        # overwriting it. `_sanitize_parameters` clamps the value to `model.config.num_labels`.
        kwargs.setdefault("top_k", 5)
        super().__init__(*args, **kwargs)

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING)

    def __call__(
        self,
        inputs: Union[np.ndarray, bytes, str, dict],
        **kwargs,
    ):
        """
        Classify the sequence(s) given as inputs. See the [`AutomaticSpeechRecognitionPipeline`] documentation for more
        information.

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                The inputs is either :
                    - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
                      to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system. A
                      string starting with `http://` or `https://` is downloaded first.
                    - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
                      same way.
                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
                        Raw audio at the correct sampling rate (no further check will be done)
                    - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int,
                      "raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
                      `"array"` is used to denote the raw audio waveform. The dict is not modified.
            top_k (`int`, *optional*, defaults to None):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
                If `None`, no per-call override is applied.

        Return:
            A list of `dict` with the following keys:

            - **label** (`str`) -- The label predicted.
            - **score** (`float`) -- The corresponding probability.
        """
        return super().__call__(inputs, **kwargs)

    def _sanitize_parameters(self, top_k=None, **kwargs):
        # `top_k` is the only parameter handled here, and it only affects `postprocess`. The two empty dicts are the
        # preprocess and forward parameters.
        postprocess_params = {}
        if top_k is not None:
            # Never ask for more labels than the model declares.
            if top_k > self.model.config.num_labels:
                top_k = self.model.config.num_labels
            postprocess_params["top_k"] = top_k
        return {}, {}, postprocess_params

    def preprocess(self, inputs):
        """
        Turn any supported input (URL / filename, file content, waveform dict or raw waveform) into the model inputs
        produced by `self.feature_extractor`.

        Raises:
            `ValueError`: If a dict input misses the required keys, if the resolved input is not a `np.ndarray`, or if
            it is not one-dimensional.
            `ImportError`: If resampling is needed and torchaudio is not available.
        """
        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()

        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

        if isinstance(inputs, dict):
            # Accepting `"array"` which is the key defined in `datasets` for
            # better integration
            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
                raise ValueError(
                    "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )

            # Read the values instead of popping them so the caller's dictionary is left untouched and can be
            # reused. `"raw"` takes precedence; `"array"` is the fallback. Other keys (e.g. `"path"`) are ignored.
            waveform = inputs.get("raw")
            if waveform is None:
                waveform = inputs.get("array")
            in_sampling_rate = inputs["sampling_rate"]
            inputs = waveform

            if in_sampling_rate != self.feature_extractor.sampling_rate:
                import torch

                if is_torchaudio_available():
                    from torchaudio import functional as F
                else:
                    raise ImportError(
                        "torchaudio is required to resample audio samples in AudioClassificationPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )

                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                ).numpy()

        if not isinstance(inputs, np.ndarray):
            raise ValueError("We expect a numpy ndarray as input")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")

        processed = self.feature_extractor(
            inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
        )
        return processed

    def _forward(self, model_inputs):
        # Plain forward pass: the preprocessed inputs are expanded as keyword arguments of the model.
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, top_k=5):
        """
        Convert the logits of the first item of `model_outputs` into a list of `{"score", "label"}` dicts, keeping the
        `top_k` most probable labels. `top_k` is clamped to the number of classes present in the logits; `None` keeps
        all of them.
        """
        probs = model_outputs.logits[0].softmax(-1)

        # `topk` raises when asked for more entries than available (e.g. the default of 5 on a model with fewer
        # labels), so clamp here as well.
        num_classes = probs.shape[-1]
        if top_k is None or top_k > num_classes:
            top_k = num_classes
        scores, ids = probs.topk(top_k)

        scores = scores.tolist()
        ids = ids.tolist()

        labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]

        return labels