Unverified Commit dbaf4920 authored by Anton Lozhkov, committed by GitHub

[Examples] Use Audio feature in speech classification (#14052)

* Update SEW integration test tolerance

* Update audio classification

* Update test

* Remove torchaudio

* Add dataset revision

* Hub branch naming

* Revert dataset revisions

* Update datasets
parent 3fefa292
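
Taken together, the commits above replace the example's hand-rolled `torchaudio` loading with the `Audio` feature that `datasets>=1.14.0` ships: casting the audio column once makes every later access return a decoded, resampled waveform. A minimal sketch of the pattern, mirroring the lines added to `run_audio_classification.py` below (the `superb`/`ks` dataset here is only an illustration):

```python
import datasets
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
raw_datasets = datasets.load_dataset("superb", "ks")  # illustrative dataset choice

# One cast replaces the removed load/mono-mix/resample pipeline: audio is now
# decoded and resampled to the feature extractor's rate on access.
raw_datasets = raw_datasets.cast_column(
    "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
```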
@@ -68,7 +68,7 @@ The following command shows how to fine-tune [wav2vec2-base](https://huggingface
 ```bash
 python run_audio_classification.py \
     --model_name_or_path facebook/wav2vec2-base \
-    --dataset_name anton-l/common_language \
+    --dataset_name common_language \
     --audio_column_name path \
     --label_column_name language \
     --output_dir wav2vec2-base-lang-id \
...
-datasets>=1.12.0
+datasets>=1.14.0
+librosa
 torchaudio
 torch>=1.6
\ No newline at end of file
@@ -22,7 +22,6 @@ from typing import Optional
 import datasets
 import numpy as np
-import torchaudio
 from datasets import DatasetDict, load_dataset

 import transformers
@@ -43,19 +42,9 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.11.0.dev0")
+check_min_version("4.12.0.dev0")

-require_version("datasets>=1.12.1", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
+require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")


-def load_audio(path: str, sample_rate: int = 16000):
-    wav, sr = torchaudio.load(path)
-    # convert multi-channel audio to mono
-    wav = wav.mean(0)
-    # standardize sample rate if it varies in the dataset
-    resampler = torchaudio.transforms.Resample(sr, sample_rate)
-    wav = resampler(wav)
-    return wav
-
-
 def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
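
For context, the surviving `random_subsample` helper crops training clips to at most `max_length` seconds at a random offset, so the train transform no longer needs a loading step of its own. A sketch consistent with that signature (the body is an illustration, not copied from the file):

```python
from random import randint

import numpy as np


def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
    """Randomly crop the waveform to at most `max_length` seconds."""
    sample_length = int(round(sample_rate * max_length))
    if len(wav) <= sample_length:
        return wav
    offset = randint(0, len(wav) - sample_length - 1)
    return wav[offset : offset + sample_length]
```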
@@ -100,8 +89,8 @@ class DataTrainingArguments:
         },
     )
     audio_column_name: Optional[str] = field(
-        default="file",
-        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'file'"},
+        default="audio",
+        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
     )
     label_column_name: Optional[str] = field(
         default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
@@ -246,13 +235,18 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )

+    # `datasets` takes care of automatically loading and resampling the audio,
+    # so we just need to set the correct target sampling rate.
+    raw_datasets = raw_datasets.cast_column(
+        data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+    )
+
     def train_transforms(batch):
         """Apply train_transforms across a batch."""
         output_batch = {"input_values": []}
-        for f in batch[data_args.audio_column_name]:
-            wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
+        for audio in batch[data_args.audio_column_name]:
             wav = random_subsample(
-                wav, max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
+                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
             )
             output_batch["input_values"].append(wav)
         output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
@@ -262,8 +256,8 @@ def main():
     def val_transforms(batch):
         """Apply val_transforms across a batch."""
         output_batch = {"input_values": []}
-        for f in batch[data_args.audio_column_name]:
-            wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
+        for audio in batch[data_args.audio_column_name]:
+            wav = audio["array"]
             output_batch["input_values"].append(wav)
         output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
@@ -311,8 +305,6 @@ def main():
         model.freeze_feature_extractor()

     if training_args.do_train:
-        if "train" not in raw_datasets:
-            raise ValueError("--do_train requires a train dataset")
         if data_args.max_train_samples is not None:
             raw_datasets["train"] = (
                 raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
@@ -321,8 +313,6 @@ def main():
         raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)

     if training_args.do_eval:
-        if "eval" not in raw_datasets:
-            raise ValueError("--do_eval requires a validation dataset")
         if data_args.max_eval_samples is not None:
             raw_datasets["eval"] = (
                 raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
...
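
The `set_transform` calls kept above are what make this change cheap: the transform runs on every `__getitem__`, so audio is decoded (and, for training, randomly subsampled) per access instead of in an up-front preprocessing pass. A self-contained sketch of that behavior, with toy data in place of audio:

```python
import numpy as np
from datasets import Dataset

ds = Dataset.from_dict({"x": [0.5, 1.5], "label": [0, 1]})


def doubling_transform(batch):
    # Called with a batch (dict of lists) on every access, so random
    # augmentations would be re-drawn each epoch instead of cached.
    return {"input_values": [np.array([v * 2.0]) for v in batch["x"]], "labels": batch["label"]}


ds.set_transform(doubling_transform, output_all_columns=False)
print(ds[0])  # roughly: {'input_values': array([1.]), 'labels': 0}
```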
@@ -113,7 +113,7 @@ def parse_args():
         "--audio_column_name",
         type=str,
         default="audio",
-        help="Column in the dataset that contains speech file path. Defaults to 'file'",
+        help="Column in the dataset that contains speech file path. Defaults to 'audio'",
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -431,9 +431,9 @@ def main():
     # via the `feature_extractor`
     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)

-    # make sure that dataset decodes audio with correct samlping rate
+    # make sure that dataset decodes audio with correct sampling rate
     raw_datasets = raw_datasets.cast_column(
-        "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+        args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
     )

     # only normalized-inputs-training is supported
...
@@ -454,9 +454,9 @@ def main():
     # so that we just need to set the correct target sampling rate and normalize the input
     # via the `feature_extractor`

-    # make sure that dataset decodes audio with correct samlping rate
+    # make sure that dataset decodes audio with correct sampling rate
     raw_datasets = raw_datasets.cast_column(
-        "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+        data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
     )

     # derive max & min input length for sample rate & max duration
...
@@ -428,7 +428,7 @@ class ExamplesTests(TestCasePlus):
             --dataset_config_name ks
             --train_split_name test
             --eval_split_name test
-            --audio_column_name file
+            --audio_column_name audio
             --label_column_name label
             --do_train
             --do_eval
...
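
Since `audio` is now the default column name, the explicit `--audio_column_name audio` above is only there for clarity; datasets whose audio column is named differently (like `path` in the README's `common_language` command) still need the flag. A quick way to check what columns a dataset exposes (the printed list is what `superb`/`ks` is expected to ship):

```python
from datasets import load_dataset

ds = load_dataset("superb", "ks", split="test")
print(ds.column_names)  # expected: ['file', 'audio', 'label']
```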