Commit 7cbf4d55 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

debugging

parent 9b7b518e
...@@ -2,10 +2,14 @@ ...@@ -2,10 +2,14 @@
python run_audio_classification.py \ python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \ --model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "vctk+facebook/voxpopuli" \ --train_dataset_name "vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
--train_dataset_config_name "default+en_accented" \ --train_dataset_config_name "default+en_accented+default" \
--train_split_name "train+test" \ --train_split_name "train+test+validation" \
--eval_dataset_name "" \ --train_label_column_name "accent" \
--eval_dataset_name "sanchit-gandhi/edacc" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--output_dir "./" \ --output_dir "./" \
--do_train \ --do_train \
--do_eval \ --do_eval \
...@@ -13,12 +17,12 @@ python run_audio_classification.py \ ...@@ -13,12 +17,12 @@ python run_audio_classification.py \
--remove_unused_columns False \ --remove_unused_columns False \
--fp16 \ --fp16 \
--learning_rate 1e-4 \ --learning_rate 1e-4 \
--max_length_seconds 10 \ --min_length_seconds 5 \
--max_length_seconds 20 \
--attention_mask False \ --attention_mask False \
--warmup_ratio 0.1 \ --warmup_ratio 0.1 \
--num_train_epochs 5 \ --num_train_epochs 5 \
--per_device_train_batch_size 32 \ --per_device_train_batch_size 32 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 32 \ --per_device_eval_batch_size 32 \
--dataloader_num_workers 4 \ --dataloader_num_workers 4 \
--logging_strategy "steps" \ --logging_strategy "steps" \
...@@ -29,4 +33,5 @@ python run_audio_classification.py \ ...@@ -29,4 +33,5 @@ python run_audio_classification.py \
--metric_for_best_model "accuracy" \ --metric_for_best_model "accuracy" \
--save_total_limit 3 \ --save_total_limit 3 \
--seed 0 \ --seed 0 \
--push_to_hub --push_to_hub \
--trust_remote_code
#!/usr/bin/env bash
# Debug/smoke-test run of the accent classifier: a tiny random Wav2Vec2 model
# trained and evaluated on 100 samples of the VoxPopuli `en_accented` test
# split, so the full pipeline can be exercised quickly on modest hardware.
args=(
  --model_name_or_path "hf-internal-testing/tiny-random-wav2vec2"
  --train_dataset_name "facebook/voxpopuli"
  --train_dataset_config_name "en_accented"
  --train_split_name "test"
  --train_label_column_name "accent"
  --eval_dataset_name "facebook/voxpopuli"
  --eval_dataset_config_name "en_accented"
  --eval_split_name "test"
  --eval_label_column_name "accent"
  --trust_remote_code
  --output_dir "./"
  --do_train
  --do_eval
  --max_train_samples 100
  --max_eval_samples 100
  --overwrite_output_dir
  --remove_unused_columns False
  --fp16
  --learning_rate 1e-4
  --min_length_seconds 5
  --max_length_seconds 10
  --attention_mask False
  --warmup_ratio 0.1
  --num_train_epochs 5
  --per_device_train_batch_size 4
  --per_device_eval_batch_size 4
  --dataloader_num_workers 0
  --logging_strategy "steps"
  --logging_steps 10
  --evaluation_strategy "epoch"
  --save_strategy "epoch"
  --load_best_model_at_end True
  --metric_for_best_model "accuracy"
  --save_total_limit 3
  --seed 0
)
python run_audio_classification.py "${args[@]}"
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import logging import logging
import os import os
import re
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from random import randint from random import randint
...@@ -36,9 +37,9 @@ from transformers import ( ...@@ -36,9 +37,9 @@ from transformers import (
TrainingArguments, TrainingArguments,
set_seed, set_seed,
) )
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version from transformers.utils import check_min_version
from transformers.models.whisper.tokenization_whisper import LANGUAGES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -56,19 +57,20 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600 ...@@ -56,19 +57,20 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600
return wav[random_offset : random_offset + sample_length] return wav[random_offset : random_offset + sample_length]
def preprocess_labels(labels: List[str]) -> List[str]: def preprocess_labels(label: str) -> str:
"""Apply pre-processing formatting to the accent labels""" """Apply pre-processing formatting to the accent labels"""
processed_labels = [] if "_" in label:
for label in labels: # voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
if "_" in label: language_code = label.split("_")[-1]
# voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent label = LANGUAGES[language_code]
language_code = label.split("_")[-1] if label == "British":
label = LANGUAGES[language_code] # 1 speaker in VCTK is labelled as British instead of English - let's normalise
if label == "British": label = "English"
# 1 speaker in VCTK is labelled as British instead of English - let's normalise # VCTK labels for two words are concatenated into one (NewZeleand-> New Zealand)
label = "English" label = re.sub(r"(\w)([A-Z])", r"\1 \2", label)
processed_labels.append(label.capitalize()) # convert Whisper language code (polish) to capitalised (Polish)
return processed_labels label = label.capitalize()
return label
@dataclass @dataclass
...@@ -161,10 +163,18 @@ class DataTrainingArguments: ...@@ -161,10 +163,18 @@ class DataTrainingArguments:
) )
}, },
) )
min_length_seconds: float = field(
default=5,
metadata={"help": "Audio clips less than this value will be filtered during training."},
)
max_length_seconds: float = field( max_length_seconds: float = field(
default=20, default=20,
metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."}, metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
) )
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
@dataclass @dataclass
...@@ -326,7 +336,7 @@ def load_multiple_datasets( ...@@ -326,7 +336,7 @@ def load_multiple_datasets(
if dataset_dict["label_column_name"] not in dataset_features: if dataset_dict["label_column_name"] not in dataset_features:
raise ValueError( raise ValueError(
f"Label column name {dataset_dict['text_column_name']} not found in dataset" f"Label column name {dataset_dict['label_column_name']} not found in dataset"
f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the" f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the"
f" correct text column - one of {', '.join(dataset_features)}." f" correct text column - one of {', '.join(dataset_features)}."
) )
...@@ -423,12 +433,12 @@ def main(): ...@@ -423,12 +433,12 @@ def main():
data_args.train_dataset_config_name, data_args.train_dataset_config_name,
splits=data_args.train_split_name, splits=data_args.train_split_name,
label_column_names=data_args.train_label_column_name, label_column_names=data_args.train_label_column_name,
streaming=data_args.streaming,
dataset_samples=data_args.train_dataset_samples, dataset_samples=data_args.train_dataset_samples,
seed=training_args.seed, seed=training_args.seed,
cache_dir=data_args.dataset_cache_dir, cache_dir=model_args.cache_dir,
token=True if model_args.token else None, token=True if model_args.token else None,
trust_remote_code=data_args.trust_remote_code, trust_remote_code=model_args.trust_remote_code,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
) )
if training_args.do_eval: if training_args.do_eval:
...@@ -449,10 +459,10 @@ def main(): ...@@ -449,10 +459,10 @@ def main():
dataset_dict["name"], dataset_dict["name"],
dataset_dict["config"], dataset_dict["config"],
split=dataset_dict["split"], split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir, cache_dir=model_args.cache_dir,
token=True if model_args.token else None, token=True if model_args.token else None,
streaming=data_args.streaming, trust_remote_code=model_args.trust_remote_code,
trust_remote_code=data_args.trust_remote_code, # streaming=data_args.streaming,
) )
else: else:
# load multiple eval sets # load multiple eval sets
...@@ -463,10 +473,10 @@ def main(): ...@@ -463,10 +473,10 @@ def main():
dataset_dict["name"], dataset_dict["name"],
dataset_dict["config"], dataset_dict["config"],
split=dataset_dict["split"], split=dataset_dict["split"],
cache_dir=data_args.dataset_cache_dir, cache_dir=model_args.cache_dir,
token=True if model_args.use_auth_token else None, token=True if model_args.use_auth_token else None,
streaming=data_args.streaming, trust_remote_code=model_args.trust_remote_code,
trust_remote_code=data_args.trust_remote_code, # streaming=data_args.streaming,
) )
features = raw_datasets[pretty_name].features.keys() features = raw_datasets[pretty_name].features.keys()
if dataset_dict["label_column_name"] not in features: if dataset_dict["label_column_name"] not in features:
...@@ -505,36 +515,70 @@ def main(): ...@@ -505,36 +515,70 @@ def main():
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
) )
model_input_name = feature_extractor.model_input_names[0] if training_args.do_train:
if data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
def train_transforms(batch): if training_args.do_eval:
"""Apply train_transforms across a batch.""" if data_args.max_eval_samples is not None:
subsampled_wavs = [] raw_datasets["eval"] = (
for audio in batch[data_args.audio_column_name]: raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
wav = random_subsample(
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
) )
subsampled_wavs.append(wav)
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate) sampling_rate = feature_extractor.sampling_rate
output_batch = {model_input_name: inputs.get(model_input_name)} model_input_name = feature_extractor.model_input_names[0]
output_batch["labels"] = preprocess_labels(batch["labels"]) max_input_length = data_args.max_length_seconds * sampling_rate
return output_batch min_input_length = data_args.min_length_seconds * sampling_rate
def val_transforms(batch): def prepare_dataset(sample):
"""Apply val_transforms across a batch.""" audio = sample["audio"]["array"]
wavs = [audio["array"] for audio in batch[data_args.audio_column_name]] if len(audio) / sampling_rate > max_input_length:
inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate) audio = random_subsample(audio, max_input_length, sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)} inputs = feature_extractor(audio, sampling_rate=sampling_rate)
output_batch["labels"] = preprocess_labels(batch["labels"]) sample[model_input_name] = inputs.get(model_input_name)
return output_batch sample["input_length"] = len(audio) / sampling_rate
sample["labels"] = preprocess_labels(sample["labels"])
return sample
vectorized_datasets = raw_datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
)
# filter training data with inputs longer than max_input_length
def is_audio_in_length_range(length):
    """Return True if a clip's duration lies strictly within the configured bounds.

    `length` is ``sample["input_length"]``, recorded by ``prepare_dataset`` in
    *seconds* (``len(audio) / sampling_rate``), while ``min_input_length`` and
    ``max_input_length`` are expressed in *samples*
    (``*_length_seconds * sampling_rate``). Convert seconds to samples before
    comparing; the previous seconds-vs-samples comparison rejected essentially
    every clip.
    """
    return min_input_length < length * sampling_rate < max_input_length
vectorized_datasets = vectorized_datasets.filter(
is_audio_in_length_range,
input_columns=["input_length"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by audio length",
)
# filter training data with non valid labels
def is_label_valid(label):
    """Return True for samples whose accent label is known.

    The label pre-processing emits the placeholder "Unknown" for accents it
    cannot resolve; those samples are dropped from training/eval.
    """
    return not (label == "Unknown")
vectorized_datasets = vectorized_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by labels",
)
# Prepare label mappings. # Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API. # We'll include these in the model's config to get human readable labels in the Inference API.
labels = raw_datasets["train"]["label"] labels = vectorized_datasets["train"]["labels"]
label2id, id2label = {}, {} label2id, id2label, num_label = {}, {}, {}
for i, label in enumerate(labels): for i, label in enumerate(labels):
label2id[label] = str(i) num_label[label] += 1
id2label[str(i)] = label if label not in label2id:
label2id[label] = str(i)
id2label[str(i)] = label
logger.info(f"Number of labels: {num_label}")
# Load the accuracy metric from the datasets package # Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir) metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
...@@ -572,32 +616,12 @@ def main(): ...@@ -572,32 +616,12 @@ def main():
if model_args.freeze_feature_encoder: if model_args.freeze_feature_encoder:
model.freeze_feature_encoder() model.freeze_feature_encoder()
if training_args.do_train:
if data_args.max_train_samples is not None:
raw_datasets["train"] = (
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
# Set the training transforms
raw_datasets["train"].set_transform(
train_transforms, columns=[model_input_name, "labels"], output_all_columns=False
)
if training_args.do_eval:
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# Set the validation transforms
raw_datasets["eval"].set_transform(
val_transforms, columns=[model_input_name, "labels"], output_all_columns=False
)
# Initialize our trainer # Initialize our trainer
trainer = Trainer( trainer = Trainer(
model=model, model=model,
args=training_args, args=training_args,
train_dataset=raw_datasets["train"] if training_args.do_train else None, train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
eval_dataset=raw_datasets["eval"] if training_args.do_eval else None, eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
compute_metrics=compute_metrics, compute_metrics=compute_metrics,
tokenizer=feature_extractor, tokenizer=feature_extractor,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment