Unverified Commit 37c5759c authored by Patrick von Platen, committed by GitHub

[Speech Examples] Add new audio feature (#14027)

* finish

* up

* finish all

* up
parent cde0c750
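
The change running through every file below: both speech examples drop the manual `torchaudio` load/resample path and rely on the `Audio` feature introduced in `datasets` 1.13, which decodes and resamples audio transparently when a sample is accessed. A minimal sketch of the pattern the diffs adopt (the dataset choice and sampling rate here are illustrative, not part of this commit):

```python
import datasets

# any dataset with an audio column works; librispeech_asr is just an example
ds = datasets.load_dataset("librispeech_asr", "clean", split="validation")

# casting the column tells `datasets` to decode and resample lazily on access
ds = ds.cast_column("audio", datasets.features.Audio(sampling_rate=16_000))

sample = ds[0]["audio"]
# `sample` is a dict holding the decoded waveform and its sampling rate
print(sample["array"].shape, sample["sampling_rate"])  # (num_samples,) 16000
```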
@@ -13,7 +13,7 @@ streamlit
 elasticsearch
 nltk
 pandas
-datasets >= 1.1.3
+datasets >= 1.13.3
 fire
 pytest
 conllu
@@ -21,3 +21,4 @@ sentencepiece != 0.1.92
 protobuf
 torchvision
 jiwer
+librosa
@@ -94,7 +94,7 @@ To pre-train `"large-sized"` Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv6
 on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
 ```bash
-accelerate launch run_pretrain_no_trainer.py \
+accelerate launch run_wav2vec2_pretraining_no_trainer.py \
 	--dataset_name=librispeech_asr \
 	--dataset_config_names clean clean other \
 	--dataset_split_names train.100 train.360 train.500 \
...
@@ -2,3 +2,4 @@ datasets >= 1.12.0
 torch >= 1.5
 torchaudio
 accelerate >= 0.5.0
+librosa
@@ -25,7 +25,6 @@ from typing import Dict, List, Optional, Union
 import datasets
 import torch
-import torchaudio
 from datasets import DatasetDict, concatenate_datasets, load_dataset
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
@@ -113,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--audio_column_name",
         type=str,
-        default="file",
-        help="Column in the dataset that contains speech file path. Defaults to 'file'",
+        default="audio",
+        help="Column in the dataset that contains the audio data. Defaults to 'audio'",
     )
     parser.add_argument(
@@ -128,6 +127,18 @@ def parse_args():
         default=None,
         help="Pretrained config name or path if not the same as model_name",
     )
+    parser.add_argument(
+        "--train_cache_file_name",
+        type=str,
+        default=None,
+        help="Path to the cached training dataset file",
+    )
+    parser.add_argument(
+        "--validation_cache_file_name",
+        type=str,
+        default=None,
+        help="Path to the cached validation dataset file",
+    )
     parser.add_argument(
         "--per_device_train_batch_size",
         type=int,
@@ -414,9 +425,17 @@ def main():
         raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
         raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))
 
-    # 2. Preprocess audio: load, resample, normalize and truncate
+    # 2. Now we preprocess the datasets, including loading the audio, resampling and normalization.
+    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
+    # so that we just need to set the correct target sampling rate and normalize the input
+    # via the `feature_extractor`
     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
 
+    # make sure that the dataset decodes audio with the correct sampling rate
+    raw_datasets = raw_datasets.cast_column(
+        "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+    )
+
     # only normalized-inputs-training is supported
     if not feature_extractor.do_normalize:
         raise ValueError(
@@ -427,38 +446,40 @@ def main():
     max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
     min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)
 
-    resampler = None
-    if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3":
-        # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
-        resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate)
-
     def prepare_dataset(batch):
-        speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name])
-        speech_array = speech_array.squeeze()
-
-        # if necessary resample audio
-        if resampler is not None:
-            # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
-            speech_array = resampler(speech_array)
-            sampling_rate = resampler.new_freq
-
-        speech_array = speech_array.numpy()
-        inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True)
+        sample = batch[args.audio_column_name]
+
+        inputs = feature_extractor(
+            sample["array"], sampling_rate=sample["sampling_rate"], max_length=max_length, truncation=True
+        )
         batch["input_values"] = inputs.input_values[0]
+        batch["input_length"] = len(inputs.input_values[0])
+
         return batch
 
+    # optionally pass user-specified cache file paths for the mapped datasets
+    cache_file_names = None
+    if args.train_cache_file_name is not None:
+        cache_file_names = {"train": args.train_cache_file_name, "validation": args.validation_cache_file_name}
+
     # load audio files into numpy arrays
     with accelerator.main_process_first():
         vectorized_datasets = raw_datasets.map(
             prepare_dataset,
             num_proc=args.preprocessing_num_workers,
             remove_columns=raw_datasets["train"].column_names,
+            cache_file_names=cache_file_names,
             load_from_cache_file=not args.overwrite_cache,
         )
-        vectorized_datasets = vectorized_datasets.filter(
-            lambda x: len(x["input_values"]) > min_length,
-        )
+
+        if min_length > 0.0:
+            vectorized_datasets = vectorized_datasets.filter(
+                lambda x: x > min_length,
+                num_proc=args.preprocessing_num_workers,
+                input_columns=["input_length"],
+            )
+
+        vectorized_datasets = vectorized_datasets.remove_columns("input_length")
 
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will most likely
     # be a timeout when running the script in distributed mode.
...
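
Two patterns in the hunk above are easy to miss: `map` gets explicit per-split `cache_file_names` so one preprocessing run can be reused later, and length filtering now reads a small `input_length` column via `input_columns` instead of deserializing the full `input_values` array for every example. A self-contained sketch with toy data (the column names mirror the script; the synthetic dataset and `/tmp` paths are illustrative):

```python
import datasets

# toy stand-in for the vectorized speech dataset
raw = datasets.DatasetDict(
    {
        "train": datasets.Dataset.from_dict({"input_values": [[0.1] * 8, [0.2] * 64]}),
        "validation": datasets.Dataset.from_dict({"input_values": [[0.3] * 32]}),
    }
)

def add_length(batch):
    # record the length once so later filters never re-read the waveform
    batch["input_length"] = len(batch["input_values"])
    return batch

# explicit cache files, mirroring --train_cache_file_name / --validation_cache_file_name
cache_file_names = {"train": "/tmp/train.arrow", "validation": "/tmp/validation.arrow"}
vectorized = raw.map(add_length, cache_file_names=cache_file_names)

min_length = 16
# `input_columns` passes only `input_length` to the lambda, so `x` is a plain int
vectorized = vectorized.filter(lambda x: x > min_length, input_columns=["input_length"])
vectorized = vectorized.remove_columns("input_length")
print(vectorized)  # train keeps the 64-sample row, validation keeps the 32-sample row
```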
@@ -58,7 +58,6 @@ python run_speech_recognition_ctc.py \
 	--learning_rate="3e-4" \
 	--warmup_steps="500" \
 	--evaluation_strategy="steps" \
-	--audio_column_name="path" \
 	--text_column_name="sentence" \
 	--save_steps="400" \
 	--eval_steps="100" \
@@ -87,7 +86,6 @@ python -m torch.distributed.launch \
 	--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
 	--dataset_config_name="tr" \
 	--output_dir="./wav2vec2-common_voice-tr-demo-dist" \
-	--preprocessing_num_workers="16" \
 	--overwrite_output_dir \
 	--num_train_epochs="15" \
 	--per_device_train_batch_size="4" \
...
@@ -1,3 +1,4 @@
-datasets >= 1.12.0
+datasets >= 1.13.3
 torch >= 1.5
 torchaudio
+librosa
@@ -24,9 +24,9 @@ import sys
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
+import datasets
 import numpy as np
 import torch
-import torchaudio
 from datasets import DatasetDict, load_dataset, load_metric
 
 import transformers
@@ -49,8 +49,7 @@ from transformers.utils.versions import require_version
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.12.0.dev0")
 
-# TODO(Patrick) Bump up as soon as audio features are merged
-require_version("datasets>=1.12.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
+require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
 
 logger = logging.getLogger(__name__)
@@ -179,12 +178,12 @@ class DataTrainingArguments:
     min_duration_in_seconds: Optional[float] = field(
         default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
     )
-    only_data_preprocessing: Optional[bool] = field(
+    preprocessing_only: Optional[bool] = field(
         default=False,
         metadata={
             "help": "Whether to only do data preprocessing and skip training. "
             "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
-            "In this case, one should run the preprocessing in a non-distributed setup with `only_data_preprocessing=True` "
+            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
             "so that the cached datasets can subsequently be loaded in distributed training"
         },
     )
@@ -450,41 +449,30 @@ def main():
     if model_args.freeze_feature_extractor:
         model.freeze_feature_extractor()
 
-    # 5. Now we preprocess the datasets which includes loading the audio, resampling and padding
-
-    # The following code should be cleaned up as soon as
-    # https://github.com/huggingface/datasets/pull/2324 is merged
-
-    # Preprocessing the datasets.
-    # We need to read the audio files as arrays and tokenize the targets.
+    # 5. Now we preprocess the datasets, including loading the audio, resampling and normalization.
+    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
+    # so that we just need to set the correct target sampling rate and normalize the input
+    # via the `feature_extractor`
+
+    # make sure that the dataset decodes audio with the correct sampling rate
+    raw_datasets = raw_datasets.cast_column(
+        "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+    )
 
     # derive max & min input length for sample rate & max duration
     max_input_length = data_args.max_duration_in_seconds * processor.feature_extractor.sampling_rate
     min_input_length = data_args.min_duration_in_seconds * processor.feature_extractor.sampling_rate
 
-    resampler = None
-    if raw_datasets["train"][data_args.audio_column_name][0].split(".")[-1] == "mp3":
-        # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
-        resampler = torchaudio.transforms.Resample(48_000, processor.feature_extractor.sampling_rate)
-
     # Preprocessing the datasets.
     # We need to read the audio files as arrays and tokenize the targets.
     def prepare_dataset(batch):
         # load audio
-        speech_array, sampling_rate = torchaudio.load(batch[data_args.audio_column_name])
-        speech_array = speech_array.squeeze()
-
-        # if necessary resample audio
-        if resampler is not None:
-            # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
-            speech_array = resampler(speech_array)
-            sampling_rate = resampler.new_freq
-
-        speech_array = speech_array.numpy()
+        sample = batch[data_args.audio_column_name]
+
         batch["input_values"] = processor(
-            speech_array, sampling_rate=sampling_rate, truncate=True, max_length=max_input_length
+            sample["array"], sampling_rate=sample["sampling_rate"], truncation=True, max_length=max_input_length
         ).input_values[0]
+        batch["input_length"] = len(batch["input_values"])
 
         # Setup the processor for targets
         with processor.as_target_processor():
@@ -502,10 +490,13 @@ def main():
     if min_input_length > 0.0:
         # filter data that is shorter than min_input_length
         vectorized_datasets = vectorized_datasets.filter(
-            lambda data: len(data["input_values"]) > min_input_length,
+            lambda x: x > min_input_length,
             num_proc=data_args.preprocessing_num_workers,
+            input_columns=["input_length"],
         )
 
+    vectorized_datasets = vectorized_datasets.remove_columns("input_length")
+
     # 6. Next, we can prepare the training.
     # Let's use word error rate (WER) as our evaluation metric,
     # instantiate a data collator and the trainer
@@ -513,8 +504,13 @@ def main():
     # Define Metric during training
     wer_metric = load_metric("wer")
 
-    if data_args.only_data_preprocessing:
-        logger.info("Data preprocessing finished.")
+    # for large datasets it is advised to run the preprocessing on a
+    # single machine first with ``args.preprocessing_only`` since there will most likely
+    # be a timeout when running the script in distributed mode.
+    # In a second step, ``args.preprocessing_only`` can then be set to `False` to load the
+    # cached dataset
+    if data_args.preprocessing_only:
+        logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return
 
     def compute_metrics(pred):
...
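
The `preprocessing_only` flag added above encodes a two-step recipe for large datasets: run the script once on a single machine so `map` materializes its Arrow cache, then relaunch in distributed mode with the flag off, and training starts from the cache instead of timing out. A condensed sketch of that control flow (the toy dataset and the `preprocess` helper are illustrative, not the script's real API):

```python
import logging

import datasets

logger = logging.getLogger(__name__)

def preprocess(raw_datasets: datasets.DatasetDict, preprocessing_only: bool):
    # toy map standing in for feature extraction and target tokenization
    vectorized = raw_datasets.map(lambda ex: {"input_length": len(ex["text"])})

    if preprocessing_only:
        # step 1: single-machine run; with on-disk datasets the Arrow cache persists
        logger.info(f"Data preprocessing finished. Files cached at {vectorized.cache_files}")
        return None

    # step 2: the distributed relaunch hits the cache and continues on to training
    return vectorized

raw = datasets.DatasetDict({"train": datasets.Dataset.from_dict({"text": ["a", "bb", "ccc"]})})
preprocess(raw, preprocessing_only=True)                # warm the cache, then exit
vectorized = preprocess(raw, preprocessing_only=False)  # reuse it and train
```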
@@ -395,7 +395,6 @@ class ExamplesTests(TestCasePlus):
     --dataset_config_name clean
     --train_split_name validation
     --eval_split_name validation
-    --audio_column_name file
     --do_train
     --do_eval
     --learning_rate 1e-4
...