Commit 7cbf4d55 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

debugging

parent 9b7b518e
......@@ -2,10 +2,14 @@
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "vctk+facebook/voxpopuli" \
--train_dataset_config_name "default+en_accented" \
--train_split_name "train+test" \
--eval_dataset_name "" \
--train_dataset_name "vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
--train_dataset_config_name "default+en_accented+default" \
--train_split_name "train+test+validation" \
--train_label_column_name "accent" \
--eval_dataset_name "sanchit-gandhi/edacc" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--output_dir "./" \
--do_train \
--do_eval \
......@@ -13,12 +17,12 @@ python run_audio_classification.py \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--max_length_seconds 10 \
--min_length_seconds 5 \
--max_length_seconds 20 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 32 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 32 \
--dataloader_num_workers 4 \
--logging_strategy "steps" \
......@@ -29,4 +33,5 @@ python run_audio_classification.py \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--seed 0 \
--push_to_hub
--push_to_hub \
--trust_remote_code
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "hf-internal-testing/tiny-random-wav2vec2" \
--train_dataset_name "facebook/voxpopuli" \
--train_dataset_config_name "en_accented" \
--train_split_name "test" \
--train_label_column_name "accent" \
--eval_dataset_name "facebook/voxpopuli" \
--eval_dataset_config_name "en_accented" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--trust_remote_code \
--output_dir "./" \
--do_train \
--do_eval \
--max_train_samples 100 \
--max_eval_samples 100 \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--min_length_seconds 5 \
--max_length_seconds 10 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--dataloader_num_workers 0 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--seed 0
......@@ -16,6 +16,7 @@
import logging
import os
import re
import sys
from dataclasses import dataclass, field
from random import randint
......@@ -36,9 +37,9 @@ from transformers import (
TrainingArguments,
set_seed,
)
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.models.whisper.tokenization_whisper import LANGUAGES
logger = logging.getLogger(__name__)
......@@ -56,19 +57,20 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600
return wav[random_offset : random_offset + sample_length]
def preprocess_labels(labels: List[str]) -> List[str]:
def preprocess_labels(label: str) -> str:
"""Apply pre-processing formatting to the accent labels"""
processed_labels = []
for label in labels:
if "_" in label:
# voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
language_code = label.split("_")[-1]
label = LANGUAGES[language_code]
if label == "British":
# 1 speaker in VCTK is labelled as British instead of English - let's normalise
label = "English"
processed_labels.append(label.capitalize())
return processed_labels
if "_" in label:
# voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
language_code = label.split("_")[-1]
label = LANGUAGES[language_code]
if label == "British":
# 1 speaker in VCTK is labelled as British instead of English - let's normalise
label = "English"
# VCTK labels for two words are concatenated into one (NewZeleand-> New Zealand)
label = re.sub(r"(\w)([A-Z])", r"\1 \2", label)
# convert Whisper language code (polish) to capitalised (Polish)
label = label.capitalize()
return label
@dataclass
......@@ -161,10 +163,18 @@ class DataTrainingArguments:
)
},
)
min_length_seconds: float = field(
default=5,
metadata={"help": "Audio clips less than this value will be filtered during training."},
)
max_length_seconds: float = field(
default=20,
metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
@dataclass
......@@ -326,7 +336,7 @@ def load_multiple_datasets(
if dataset_dict["label_column_name"] not in dataset_features:
raise ValueError(
f"Label column name {dataset_dict['text_column_name']} not found in dataset"
f"Label column name {dataset_dict['label_column_name']} not found in dataset"
f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the"
f" correct text column - one of {', '.join(dataset_features)}."
)
......@@ -423,12 +433,12 @@ def main():
data_args.train_dataset_config_name,
splits=data_args.train_split_name,
label_column_names=data_args.train_label_column_name,
streaming=data_args.streaming,
dataset_samples=data_args.train_dataset_samples,
seed=training_args.seed,
cache_dir=data_args.dataset_cache_dir,
cache_dir=model_args.cache_dir,
token=True if model_args.token else None,
trust_remote_code=data_args.trust_remote_code,
trust_remote_code=model_args.trust_remote_code,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
if training_args.do_eval:
......@@ -449,10 +459,10 @@ def main():
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
cache_dir=model_args.cache_dir,
token=True if model_args.token else None,
streaming=data_args.streaming,
trust_remote_code=data_args.trust_remote_code,
trust_remote_code=model_args.trust_remote_code,
# streaming=data_args.streaming,
)
else:
# load multiple eval sets
......@@ -463,10 +473,10 @@ def main():
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir,
cache_dir=model_args.cache_dir,
token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
trust_remote_code=data_args.trust_remote_code,
trust_remote_code=model_args.trust_remote_code,
# streaming=data_args.streaming,
)
features = raw_datasets[pretty_name].features.keys()
if dataset_dict["label_column_name"] not in features:
......@@ -505,36 +515,70 @@ def main():
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
model_input_name = feature_extractor.model_input_names[0]
if training_args.do_train:
if data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
def train_transforms(batch):
"""Apply train_transforms across a batch."""
subsampled_wavs = []
for audio in batch[data_args.audio_column_name]:
wav = random_subsample(
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
if training_args.do_eval:
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
subsampled_wavs.append(wav)
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["labels"] = preprocess_labels(batch["labels"])
return output_batch
def val_transforms(batch):
"""Apply val_transforms across a batch."""
wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["labels"] = preprocess_labels(batch["labels"])
return output_batch
sampling_rate = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]
max_input_length = data_args.max_length_seconds * sampling_rate
min_input_length = data_args.min_length_seconds * sampling_rate
def prepare_dataset(sample):
audio = sample["audio"]["array"]
if len(audio) / sampling_rate > max_input_length:
audio = random_subsample(audio, max_input_length, sampling_rate)
inputs = feature_extractor(audio, sampling_rate=sampling_rate)
sample[model_input_name] = inputs.get(model_input_name)
sample["input_length"] = len(audio) / sampling_rate
sample["labels"] = preprocess_labels(sample["labels"])
return sample
vectorized_datasets = raw_datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
)
# filter training data with inputs longer than max_input_length
def is_audio_in_length_range(length):
return min_input_length < length < max_input_length
vectorized_datasets = vectorized_datasets.filter(
is_audio_in_length_range,
input_columns=["input_length"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by audio length",
)
# filter training data with non valid labels
def is_label_valid(label):
return label != "Unknown"
vectorized_datasets = vectorized_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by labels",
)
# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = raw_datasets["train"]["label"]
label2id, id2label = {}, {}
labels = vectorized_datasets["train"]["labels"]
label2id, id2label, num_label = {}, {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
id2label[str(i)] = label
num_label[label] += 1
if label not in label2id:
label2id[label] = str(i)
id2label[str(i)] = label
logger.info(f"Number of labels: {num_label}")
# Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
......@@ -572,32 +616,12 @@ def main():
if model_args.freeze_feature_encoder:
model.freeze_feature_encoder()
if training_args.do_train:
if data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
# Set the training transforms
raw_datasets["train"].set_transform(
train_transforms, columns=[model_input_name, "labels"], output_all_columns=False
)
if training_args.do_eval:
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# Set the validation transforms
raw_datasets["eval"].set_transform(
val_transforms, columns=[model_input_name, "labels"], output_all_columns=False
)
# Initialize our trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=raw_datasets["train"] if training_args.do_train else None,
eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment