"library/vscode:/vscode.git/clone" did not exist on "5d0154529fa1e2431b13a74b4634ba2ba2308fa1"
Commit 9ab35c23 authored by sanchit-gandhi

iterate on audio class

parent 4664d695
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "vctk+facebook/voxpopuli" \
--train_dataset_config_name "default+en_accented" \
--train_split_name "train+test" \
--eval_dataset_name "" \
--output_dir "./" \
--do_train \
--do_eval \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--max_length_seconds 10 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 32 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 32 \
--dataloader_num_workers 4 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--seed 0 \
--push_to_hub
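
The `+`-separated arguments above pair up positionally across dataset names, configs and splits. A minimal sketch of that convention (illustrative only; it simply mirrors how the strings in the command line up, not the script's internal parsing):

```python
# Illustrative pairing of the "+"-separated launch arguments above
dataset_names = "vctk+facebook/voxpopuli".split("+")  # ["vctk", "facebook/voxpopuli"]
config_names = "default+en_accented".split("+")       # ["default", "en_accented"]
split_names = "train+test".split("+")                 # ["train", "test"]

for name, config, split in zip(dataset_names, config_names, split_names):
    print(f"dataset={name}, config={config}, split={split}")
# dataset=vctk, config=default, split=train
# dataset=facebook/voxpopuli, config=en_accented, split=test
```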
@@ -37,7 +37,8 @@ from transformers import (
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils import check_min_version
from transformers.models.whisper.tokenization_whisper import LANGUAGES
logger = logging.getLogger(__name__)
@@ -46,7 +47,7 @@ logger = logging.getLogger(__name__)
check_min_version("4.38.0.dev0")
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
"""Randomly sample chunks of `max_length` seconds from the input audio"""
sample_length = int(round(sample_rate * max_length))
if len(wav) <= sample_length:
@@ -55,6 +56,21 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
return wav[random_offset : random_offset + sample_length]
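
A quick usage sketch (hypothetical values; assumes `random_subsample` above is in scope, and reuses the 16 kHz / 10 s settings from the launch command):

```python
import numpy as np

wav = np.random.randn(30 * 16000).astype(np.float32)  # 30 s of dummy audio at 16 kHz
chunk = random_subsample(wav, max_length=10.0, sample_rate=16000)
print(chunk.shape)  # (160000,) - a random 10-second window of the input
```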
def preprocess_labels(labels: List[str]) -> List[str]:
"""Apply pre-processing formatting to the accent labels"""
processed_labels = []
for label in labels:
if "_" in label:
# voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
language_code = label.split("_")[-1]
label = LANGUAGES[language_code]
if label == "British":
# 1 speaker in VCTK is labelled as British instead of English - let's normalise
label = "English"
processed_labels.append(label.capitalize())
return processed_labels
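
For example, assuming the Whisper `LANGUAGES` mapping (language code to lower-case name, e.g. `"pl"` -> `"polish"`), the normalisation behaves as follows:

```python
print(preprocess_labels(["en_pl", "British", "english"]))
# ['Polish', 'English', 'English']
```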
@dataclass
class DataTrainingArguments:
"""
@@ -79,6 +95,12 @@
"multiple datasets by separating dataset configs by a '+' symbol."
},
)
train_split_name: str = field(
default="train",
metadata={
"help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
},
)
train_dataset_samples: str = field(
default=None,
metadata={
@@ -98,6 +120,15 @@
"help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
},
)
eval_split_name: str = field(
default="validation",
metadata={
"help": (
"The name of the evaluation data set split to use (via the datasets"
" library). Defaults to 'validation'"
)
},
)
audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
@@ -200,12 +231,6 @@ def convert_dataset_str_to_list(
):
if isinstance(dataset_names, str):
dataset_names = dataset_names.split("+")
# we assume that all the datasets we're using derive from the distil-whisper org on the Hub - prepend the org name if necessary
for i in range(len(dataset_names)):
ds_name = dataset_names[i]
dataset_names[i] = f"distil-whisper/{ds_name}" if "/" not in ds_name else ds_name
dataset_config_names = dataset_config_names.split("+")
splits = splits.split("+") if splits is not None else None
label_column_names = label_column_names.split("+") if label_column_names is not None else None
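
Only the argument-splitting part of `convert_dataset_str_to_list` is visible here; judging from the later `dataset_dict["label_column_name"]` access, it pairs the split values up into one dict per dataset. A hedged sketch of that shape (the key names other than `label_column_name` are assumptions):

```python
def convert_dataset_str_to_list_sketch(dataset_names, dataset_config_names, splits=None, label_column_names=None):
    """Illustrative only: zip '+'-separated dataset arguments into per-dataset dicts."""
    dataset_names = dataset_names.split("+")
    dataset_config_names = dataset_config_names.split("+")
    splits = splits.split("+") if splits is not None else ["train"] * len(dataset_names)
    label_column_names = (
        label_column_names.split("+") if label_column_names is not None else ["label"] * len(dataset_names)
    )
    return [
        {"name": n, "config": c, "split": s, "label_column_name": l}
        for n, c, s, l in zip(dataset_names, dataset_config_names, splits, label_column_names)
    ]

# e.g. convert_dataset_str_to_list_sketch("vctk+facebook/voxpopuli", "default+en_accented", "train+test")
# -> one dict for (vctk, default, train) and one for (facebook/voxpopuli, en_accented, test)
```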
@@ -345,10 +370,6 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_audio_classification", model_args, data_args)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -410,8 +431,6 @@
trust_remote_code=data_args.trust_remote_code,
)
raw_datasets_train_features = raw_datasets["train"].features.keys()
if training_args.do_eval:
dataset_names_dict = convert_dataset_str_to_list(
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
@@ -452,7 +471,7 @@
features = raw_datasets[pretty_name].features.keys()
if dataset_dict["label_column_name"] not in features:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
f"--label_column_name {data_args.eval_label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
@@ -498,6 +517,7 @@
subsampled_wavs.append(wav)
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["labels"] = preprocess_labels(batch["labels"])
return output_batch
def val_transforms(batch):
@@ -505,11 +525,12 @@
wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["labels"] = preprocess_labels(batch["labels"])
return output_batch
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = raw_datasets["train"].features[data_args.label_column_name].names
labels = raw_datasets["train"]["label"]
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
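
Since `raw_datasets["train"]["label"]` yields one entry per training example, a deduplication pass along these lines (a sketch under that assumption, not code from the commit) gives unique, human-readable class names for the model config:

```python
# Sketch: collapse the per-example label column to unique class names before building the mappings
unique_labels = sorted(set(raw_datasets["train"]["label"]))
label2id = {label: str(i) for i, label in enumerate(unique_labels)}
id2label = {str(i): label for i, label in enumerate(unique_labels)}
```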