Commit 7cbf4d55 authored by sanchit-gandhi

debugging

parent 9b7b518e
@@ -2,10 +2,14 @@
 python run_audio_classification.py \
 --model_name_or_path "facebook/mms-lid-126" \
---train_dataset_name "vctk+facebook/voxpopuli" \
---train_dataset_config_name "default+en_accented" \
---train_split_name "train+test" \
---eval_dataset_name "" \
+--train_dataset_name "vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
+--train_dataset_config_name "default+en_accented+default" \
+--train_split_name "train+test+validation" \
+--train_label_column_name "accent" \
+--eval_dataset_name "sanchit-gandhi/edacc" \
+--eval_dataset_config_name "default" \
+--eval_split_name "test" \
+--eval_label_column_name "accent" \
 --output_dir "./" \
 --do_train \
 --do_eval \
@@ -13,12 +17,12 @@ python run_audio_classification.py \
 --remove_unused_columns False \
 --fp16 \
 --learning_rate 1e-4 \
---max_length_seconds 10 \
+--min_length_seconds 5 \
+--max_length_seconds 20 \
 --attention_mask False \
 --warmup_ratio 0.1 \
 --num_train_epochs 5 \
 --per_device_train_batch_size 32 \
---gradient_accumulation_steps 4 \
 --per_device_eval_batch_size 32 \
 --dataloader_num_workers 4 \
 --logging_strategy "steps" \
@@ -29,4 +33,5 @@ python run_audio_classification.py \
 --metric_for_best_model "accuracy" \
 --save_total_limit 3 \
 --seed 0 \
---push_to_hub
+--push_to_hub \
+--trust_remote_code
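
The plus-separated dataset arguments above follow the script's multi-dataset convention: the n-th name, config and split belong together. A minimal sketch of how such strings are presumably paired up element-wise; the load_multiple_datasets helper referenced later in the diff is not shown in full, so the keys and pairing below are an assumption:

train_dataset_name = "vctk+facebook/voxpopuli+sanchit-gandhi/edacc"
train_dataset_config_name = "default+en_accented+default"
train_split_name = "train+test+validation"
train_label_column_name = "accent"

names = train_dataset_name.split("+")
configs = train_dataset_config_name.split("+")
splits = train_split_name.split("+")
# assume a single label column name applies to every training dataset
label_columns = [train_label_column_name] * len(names)

dataset_dicts = [
    {"name": n, "config": c, "split": s, "label_column_name": l}
    for n, c, s, l in zip(names, configs, splits, label_columns)
]
print(dataset_dicts[0])
# {'name': 'vctk', 'config': 'default', 'split': 'train', 'label_column_name': 'accent'}
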
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "hf-internal-testing/tiny-random-wav2vec2" \
--train_dataset_name "facebook/voxpopuli" \
--train_dataset_config_name "en_accented" \
--train_split_name "test" \
--train_label_column_name "accent" \
--eval_dataset_name "facebook/voxpopuli" \
--eval_dataset_config_name "en_accented" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--trust_remote_code \
--output_dir "./" \
--do_train \
--do_eval \
--max_train_samples 100 \
--max_eval_samples 100 \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--learning_rate 1e-4 \
--min_length_seconds 5 \
--max_length_seconds 10 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--dataloader_num_workers 0 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--seed 0
@@ -16,6 +16,7 @@
 import logging
 import os
+import re
 import sys
 from dataclasses import dataclass, field
 from random import randint
@@ -36,9 +37,9 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
+from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
-from transformers.models.whisper.tokenization_whisper import LANGUAGES

 logger = logging.getLogger(__name__)
@@ -56,10 +57,8 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
     return wav[random_offset : random_offset + sample_length]


-def preprocess_labels(labels: List[str]) -> List[str]:
+def preprocess_labels(label: str) -> str:
     """Apply pre-processing formatting to the accent labels"""
-    processed_labels = []
-    for label in labels:
     if "_" in label:
         # voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
         language_code = label.split("_")[-1]
@@ -67,8 +66,11 @@ def preprocess_labels(labels: List[str]) -> List[str]:
     if label == "British":
         # 1 speaker in VCTK is labelled as British instead of English - let's normalise
         label = "English"
-    processed_labels.append(label.capitalize())
-    return processed_labels
+    # VCTK labels for two words are concatenated into one (NewZealand -> New Zealand)
+    label = re.sub(r"(\w)([A-Z])", r"\1 \2", label)
+    # convert Whisper language code (polish) to capitalised (Polish)
+    label = label.capitalize()
+    return label


 @dataclass
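
For reference, the two normalisation steps added to preprocess_labels behave as follows on the label styles named in the comments. The LANGUAGES lookup for the voxpopuli code happens just outside the lines shown in the hunk, so treating it as the conversion step is an assumption:

import re

from transformers.models.whisper.tokenization_whisper import LANGUAGES

# voxpopuli-style code: "en_pl" -> language code "pl" -> Whisper language name "polish"
language_code = "en_pl".split("_")[-1]
print(LANGUAGES[language_code])  # polish

# VCTK-style concatenated label: insert a space before each internal capital letter
print(re.sub(r"(\w)([A-Z])", r"\1 \2", "NewZealand"))  # New Zealand
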
@@ -161,10 +163,18 @@ class DataTrainingArguments:
             )
         },
     )
+    min_length_seconds: float = field(
+        default=5,
+        metadata={"help": "Audio clips less than this value will be filtered during training."},
+    )
     max_length_seconds: float = field(
         default=20,
         metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )


 @dataclass
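
These dataclass fields are what back the new --min_length_seconds and --preprocessing_num_workers flags used in the launch scripts. A minimal sketch with a stand-in dataclass (not the real DataTrainingArguments) showing how HfArgumentParser turns such fields into command-line arguments:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniDataArgs:
    min_length_seconds: float = field(default=5, metadata={"help": "Filter clips shorter than this."})
    preprocessing_num_workers: Optional[int] = field(default=None, metadata={"help": "Processes for preprocessing."})


# parse an explicit argument list instead of sys.argv for illustration
(args,) = HfArgumentParser(MiniDataArgs).parse_args_into_dataclasses(
    ["--min_length_seconds", "7.5", "--preprocessing_num_workers", "4"]
)
print(args.min_length_seconds, args.preprocessing_num_workers)  # 7.5 4
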
@@ -326,7 +336,7 @@ def load_multiple_datasets(
         if dataset_dict["label_column_name"] not in dataset_features:
             raise ValueError(
-                f"Label column name {dataset_dict['text_column_name']} not found in dataset"
+                f"Label column name {dataset_dict['label_column_name']} not found in dataset"
                 f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the"
                 f" correct text column - one of {', '.join(dataset_features)}."
             )
@@ -423,12 +433,12 @@ def main():
             data_args.train_dataset_config_name,
             splits=data_args.train_split_name,
             label_column_names=data_args.train_label_column_name,
-            streaming=data_args.streaming,
             dataset_samples=data_args.train_dataset_samples,
             seed=training_args.seed,
-            cache_dir=data_args.dataset_cache_dir,
+            cache_dir=model_args.cache_dir,
             token=True if model_args.token else None,
-            trust_remote_code=data_args.trust_remote_code,
+            trust_remote_code=model_args.trust_remote_code,
+            # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )

     if training_args.do_eval:
@@ -449,10 +459,10 @@ def main():
                 dataset_dict["name"],
                 dataset_dict["config"],
                 split=dataset_dict["split"],
-                cache_dir=data_args.dataset_cache_dir,
+                cache_dir=model_args.cache_dir,
                 token=True if model_args.token else None,
-                streaming=data_args.streaming,
-                trust_remote_code=data_args.trust_remote_code,
+                trust_remote_code=model_args.trust_remote_code,
+                # streaming=data_args.streaming,
             )
         else:
             # load multiple eval sets
@@ -463,10 +473,10 @@ def main():
                     dataset_dict["name"],
                     dataset_dict["config"],
                     split=dataset_dict["split"],
-                    cache_dir=data_args.dataset_cache_dir,
+                    cache_dir=model_args.cache_dir,
                     token=True if model_args.use_auth_token else None,
-                    streaming=data_args.streaming,
-                    trust_remote_code=data_args.trust_remote_code,
+                    trust_remote_code=model_args.trust_remote_code,
+                    # streaming=data_args.streaming,
                 )
                 features = raw_datasets[pretty_name].features.keys()
                 if dataset_dict["label_column_name"] not in features:
@@ -505,37 +515,71 @@ def main():
         data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
     )

+    if training_args.do_train:
+        if data_args.max_train_samples is not None:
+            raw_datasets["train"] = (
+                raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
+            )
+
+    if training_args.do_eval:
+        if data_args.max_eval_samples is not None:
+            raw_datasets["eval"] = (
+                raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
+            )
+
+    sampling_rate = feature_extractor.sampling_rate
     model_input_name = feature_extractor.model_input_names[0]
+    max_input_length = data_args.max_length_seconds * sampling_rate
+    min_input_length = data_args.min_length_seconds * sampling_rate

-    def train_transforms(batch):
-        """Apply train_transforms across a batch."""
-        subsampled_wavs = []
-        for audio in batch[data_args.audio_column_name]:
-            wav = random_subsample(
-                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
-            )
-            subsampled_wavs.append(wav)
-        inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
-        output_batch = {model_input_name: inputs.get(model_input_name)}
-        output_batch["labels"] = preprocess_labels(batch["labels"])
-        return output_batch
-
-    def val_transforms(batch):
-        """Apply val_transforms across a batch."""
-        wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
-        inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
-        output_batch = {model_input_name: inputs.get(model_input_name)}
-        output_batch["labels"] = preprocess_labels(batch["labels"])
-        return output_batch
+    def prepare_dataset(sample):
+        audio = sample["audio"]["array"]
+        if len(audio) / sampling_rate > max_input_length:
+            audio = random_subsample(audio, max_input_length, sampling_rate)
+        inputs = feature_extractor(audio, sampling_rate=sampling_rate)
+        sample[model_input_name] = inputs.get(model_input_name)
+        sample["input_length"] = len(audio) / sampling_rate
+        sample["labels"] = preprocess_labels(sample["labels"])
+        return sample
+
+    vectorized_datasets = raw_datasets.map(
+        prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
+    )
+
+    # filter training data with inputs longer than max_input_length
+    def is_audio_in_length_range(length):
+        return min_input_length < length < max_input_length
+
+    vectorized_datasets = vectorized_datasets.filter(
+        is_audio_in_length_range,
+        input_columns=["input_length"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Filtering by audio length",
+    )
+
+    # filter training data with non valid labels
+    def is_label_valid(label):
+        return label != "Unknown"
+
+    vectorized_datasets = vectorized_datasets.filter(
+        is_label_valid,
+        input_columns=["labels"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Filtering by labels",
+    )

     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
-    labels = raw_datasets["train"]["label"]
-    label2id, id2label = {}, {}
+    labels = vectorized_datasets["train"]["labels"]
+    label2id, id2label, num_label = {}, {}, {}
     for i, label in enumerate(labels):
-        label2id[label] = str(i)
-        id2label[str(i)] = label
+        num_label[label] += 1
+        if label not in label2id:
+            label2id[label] = str(i)
+            id2label[str(i)] = label
+    logger.info(f"Number of labels: {num_label}")

     # Load the accuracy metric from the datasets package
     metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
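
The map/filter combination above replaces the previous on-the-fly set_transform approach with eager preprocessing. A small self-contained illustration of the filter pattern with input_columns, on a toy in-memory dataset rather than the real audio data (the column values here are made up):

from datasets import Dataset

toy = Dataset.from_dict(
    {"input_length": [2.0, 7.5, 30.0], "labels": ["English", "Polish", "Unknown"]}
)

min_length_seconds, max_length_seconds = 5, 20

# keep rows whose length lies strictly between the two bounds
toy = toy.filter(
    lambda length: min_length_seconds < length < max_length_seconds,
    input_columns=["input_length"],
)
# drop rows whose label is "Unknown"
toy = toy.filter(lambda label: label != "Unknown", input_columns=["labels"])

print(toy["labels"])  # ['Polish']
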
@@ -572,32 +616,12 @@ def main():
     if model_args.freeze_feature_encoder:
         model.freeze_feature_encoder()

-    if training_args.do_train:
-        if data_args.max_train_samples is not None:
-            raw_datasets["train"] = (
-                raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
-            )
-        # Set the training transforms
-        raw_datasets["train"].set_transform(
-            train_transforms, columns=[model_input_name, "labels"], output_all_columns=False
-        )
-
-    if training_args.do_eval:
-        if data_args.max_eval_samples is not None:
-            raw_datasets["eval"] = (
-                raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
-            )
-        # Set the validation transforms
-        raw_datasets["eval"].set_transform(
-            val_transforms, columns=[model_input_name, "labels"], output_all_columns=False
-        )

     # Initialize our trainer
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=raw_datasets["train"] if training_args.do_train else None,
-        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
+        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
+        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
         tokenizer=feature_extractor,
     )
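
The Trainer is wired to a compute_metrics callback that this commit does not touch, so it is absent from the diff. It is typically the standard accuracy computation from the upstream audio-classification example, roughly:

import evaluate
import numpy as np

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy on a batch of predictions (logits of shape [num_examples, num_labels])."""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
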