Unverified Commit dbaf4920 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

[Examples] Use Audio feature in speech classification (#14052)

* Update SEW integration test tolerance

* Update audio classification

* Update test

* Remove torchaudio

* Add dataset revision

* Hub branch naming

* Revert dataset revisions

* Update datasets
parent 3fefa292
......@@ -68,7 +68,7 @@ The following command shows how to fine-tune [wav2vec2-base](https://huggingface
```bash
python run_audio_classification.py \
--model_name_or_path facebook/wav2vec2-base \
--dataset_name anton-l/common_language \
--dataset_name common_language \
--audio_column_name path \
--label_column_name language \
--output_dir wav2vec2-base-lang-id \
......
datasets>=1.12.0
datasets>=1.14.0
librosa
torchaudio
torch>=1.6
\ No newline at end of file
......@@ -22,7 +22,6 @@ from typing import Optional
import datasets
import numpy as np
import torchaudio
from datasets import DatasetDict, load_dataset
import transformers
......@@ -43,19 +42,9 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.11.0.dev0")
check_min_version("4.12.0.dev0")
require_version("datasets>=1.12.1", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
def load_audio(path: str, sample_rate: int = 16000):
wav, sr = torchaudio.load(path)
# convert multi-channel audio to mono
wav = wav.mean(0)
# standardize sample rate if it varies in the dataset
resampler = torchaudio.transforms.Resample(sr, sample_rate)
wav = resampler(wav)
return wav
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
......@@ -100,8 +89,8 @@ class DataTrainingArguments:
},
)
audio_column_name: Optional[str] = field(
default="file",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'file'"},
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
label_column_name: Optional[str] = field(
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
......@@ -246,13 +235,18 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
# `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
def train_transforms(batch):
"""Apply train_transforms across a batch."""
output_batch = {"input_values": []}
for f in batch[data_args.audio_column_name]:
wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
for audio in batch[data_args.audio_column_name]:
wav = random_subsample(
wav, max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
)
output_batch["input_values"].append(wav)
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
......@@ -262,8 +256,8 @@ def main():
def val_transforms(batch):
"""Apply val_transforms across a batch."""
output_batch = {"input_values": []}
for f in batch[data_args.audio_column_name]:
wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
for audio in batch[data_args.audio_column_name]:
wav = audio["array"]
output_batch["input_values"].append(wav)
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
......@@ -311,8 +305,6 @@ def main():
model.freeze_feature_extractor()
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
if data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
......@@ -321,8 +313,6 @@ def main():
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
if training_args.do_eval:
if "eval" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
......
......@@ -113,7 +113,7 @@ def parse_args():
"--audio_column_name",
type=str,
default="audio",
help="Column in the dataset that contains speech file path. Defaults to 'file'",
help="Column in the dataset that contains speech file path. Defaults to 'audio'",
)
parser.add_argument(
"--model_name_or_path",
......@@ -431,9 +431,9 @@ def main():
# via the `feature_extractor`
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
# make sure that dataset decodes audio with correct samlping rate
# make sure that dataset decodes audio with correct sampling rate
raw_datasets = raw_datasets.cast_column(
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
# only normalized-inputs-training is supported
......
......@@ -454,9 +454,9 @@ def main():
# so that we just need to set the correct target sampling rate and normalize the input
# via the `feature_extractor`
# make sure that dataset decodes audio with correct samlping rate
# make sure that dataset decodes audio with correct sampling rate
raw_datasets = raw_datasets.cast_column(
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
# derive max & min input length for sample rate & max duration
......
......@@ -428,7 +428,7 @@ class ExamplesTests(TestCasePlus):
--dataset_config_name ks
--train_split_name test
--eval_split_name test
--audio_column_name file
--audio_column_name audio
--label_column_name label
--do_train
--do_eval
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment