Unverified Commit 860ea8a5 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding `audio-classification` example in the doc. (#20235)

* Adding `audio-classification` example in the doc.

* Adding `>>>` to get the real test.

* Removing assert.

* Fixup.
parent a00b7e85
...@@ -16,6 +16,8 @@ from typing import Union ...@@ -16,6 +16,8 @@ from typing import Union
import numpy as np import numpy as np
import requests
from ..utils import add_end_docstrings, is_torch_available, logging from ..utils import add_end_docstrings, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
...@@ -69,6 +71,24 @@ class AudioClassificationPipeline(Pipeline): ...@@ -69,6 +71,24 @@ class AudioClassificationPipeline(Pipeline):
raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio
formats. formats.
Example:
```python
>>> from transformers import pipeline
>>> classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
>>> result = classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
>>> # Simplify results, different torch versions might alter the scores slightly.
>>> from transformers.testing_utils import nested_simplify
>>> nested_simplify(result)
[{'score': 0.997, 'label': '_unknown_'}, {'score': 0.002, 'label': 'left'}, {'score': 0.0, 'label': 'yes'}, {'score': 0.0, 'label': 'down'}, {'score': 0.0, 'label': 'stop'}]
```
[Using pipelines in a webserver or with a dataset](../pipeline_tutorial)
This pipeline can currently be loaded from [`pipeline`] using the following task identifier: This pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"audio-classification"`. `"audio-classification"`.
...@@ -126,8 +146,13 @@ class AudioClassificationPipeline(Pipeline): ...@@ -126,8 +146,13 @@ class AudioClassificationPipeline(Pipeline):
def preprocess(self, inputs): def preprocess(self, inputs):
if isinstance(inputs, str): if isinstance(inputs, str):
with open(inputs, "rb") as f: if inputs.startswith("http://") or inputs.startswith("https://"):
inputs = f.read() # We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
inputs = requests.get(inputs).content
else:
with open(inputs, "rb") as f:
inputs = f.read()
if isinstance(inputs, bytes): if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate) inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment