Commit 0d5d9970 authored by sanchit-gandhi

concat classification

parent b7b225a4
@@ -2,36 +2,37 @@
 python run_audio_classification.py \
     --model_name_or_path "facebook/mms-lid-126" \
-    --train_dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
-    --train_dataset_config_name "default+en_accented+default" \
-    --train_split_name "train+test+validation" \
-    --train_label_column_name "accent+accent+accent" \
-    --eval_dataset_name "sanchit-gandhi/edacc" \
+    --train_dataset_name "stable-speech/concatenated-accent-dataset" \
+    --train_dataset_config_name "default" \
+    --train_split_name "train" \
+    --train_label_column_name "labels" \
+    --eval_dataset_name "stable-speech/concatenated-accent-dataset" \
     --eval_dataset_config_name "default" \
     --eval_split_name "test" \
-    --eval_label_column_name "accent" \
+    --eval_label_column_name "labels" \
     --output_dir "./" \
     --do_train \
     --do_eval \
     --overwrite_output_dir \
     --remove_unused_columns False \
     --fp16 \
     --fp16_full_eval \
     --learning_rate 1e-4 \
     --max_length_seconds 20 \
-    --attention_mask False \
-    --warmup_ratio 0.1 \
-    --num_train_epochs 5 \
+    --min_length_seconds 5 \
+    --attention_mask \
+    --warmup_steps 100 \
+    --max_steps 1000 \
     --per_device_train_batch_size 32 \
     --per_device_eval_batch_size 32 \
-    --preprocessing_num_workers 16 \
+    --preprocessing_num_workers 4 \
     --dataloader_num_workers 4 \
     --logging_strategy "steps" \
     --logging_steps 10 \
-    --evaluation_strategy "epoch" \
-    --save_strategy "epoch" \
-    --load_best_model_at_end True \
-    --metric_for_best_model "accuracy" \
-    --save_total_limit 3 \
-    --freeze_base_model \
-    --push_to_hub \
+    --evaluation_strategy "steps" \
+    --eval_steps 500 \
+    --save_strategy "steps" \
+    --save_steps 1000 \
+    --freeze_base_model False \
+    --push_to_hub False \
     --trust_remote_code
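
The launch command above no longer interleaves VCTK, VoxPopuli and EdAcc at runtime through the `+`-separated dataset arguments; it instead reads a single pre-concatenated dataset whose examples already carry a unified `labels` column (hence `--train_label_column_name "labels"`). As a rough sketch of how such a dataset could be assembled offline (the build script is not part of this commit; the column selection, 16 kHz cast and final push are assumptions):

    from datasets import Audio, concatenate_datasets, load_dataset

    # The three source corpora and the splits/label columns the old command
    # passed through the "+"-separated CLI syntax.
    sources = [
        ("sanchit-gandhi/vctk", "default", "train", "accent"),
        ("facebook/voxpopuli", "en_accented", "test", "accent"),
        ("sanchit-gandhi/edacc", "default", "validation", "accent"),
    ]

    parts = []
    for name, config, split, label_column in sources:
        ds = load_dataset(name, config, split=split)
        # Align schemas before concatenating: one label column name and one
        # audio sampling rate across all corpora.
        ds = ds.rename_column(label_column, "labels")
        ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
        parts.append(ds.select_columns(["audio", "labels"]))

    combined = concatenate_datasets(parts)
    combined.push_to_hub("stable-speech/concatenated-accent-dataset")  # target repo taken from the diff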
@@ -17,25 +17,23 @@ metric:
   name: eval/accuracy
 parameters:
   model_name_or_path:
-    values:
-      - facebook/mms-lid-126
-      - openai/whisper-large-v3
+    value: facebook/mms-lid-126
   train_dataset_name:
-    value: sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc
+    value: stable-speech/concatenated-accent-dataset
   train_dataset_config_name:
-    value: default+en_accented+default
+    value: default
   train_split_name:
-    value: train+test+validation
+    value: train
   train_label_column_name:
-    value: accent+accent+accent
+    value: labels
   eval_dataset_name:
-    value: sanchit-gandhi/edacc
+    value: stable-speech/concatenated-accent-dataset
   eval_dataset_config_name:
     value: default
   eval_split_name:
     value: test
   eval_label_column_name:
-    value: accent
+    value: labels
   output_dir:
     value: ./
   remove_unused_columns:
@@ -45,13 +43,13 @@ parameters:
   lr_scheduler_type:
     value: constant_with_warmup
   max_length_seconds:
-    value: 10 # give some data diversity for longer audio samples
+    value: 20 # give some data diversity for longer audio samples
   min_length_seconds:
-    value: 5
+    value: 7
   attention_mask:
-    value: false
+    value: true
   warmup_steps:
-    value: 50
+    value: 100
   max_steps:
     value: 2000
   per_device_train_batch_size:
@@ -59,7 +57,7 @@ parameters:
   per_device_eval_batch_size:
     value: 16
   preprocessing_num_workers:
-    value: 16
+    value: 4
   dataloader_num_workers:
     value: 4
   logging_strategy:
@@ -69,7 +67,7 @@ parameters:
   evaluation_strategy:
     value: steps
   eval_steps:
-    value: 2000
+    value: 1000
   save_strategy:
     value: steps
   save_steps:
@@ -77,7 +75,11 @@ parameters:
   metric_for_best_model:
     value: accuracy
   freeze_base_model:
-    value: false
+    values:
+      - false
+      - true
+  group_by_length:
+    value: false # TODO(SG): batch by length
   push_to_hub:
     value: false
 program: run_audio_classification.py
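
The file above is a Weights & Biases sweep configuration: with `values: [false, true]` under freeze_base_model and a single `value` everywhere else, a grid sweep expands to exactly two runs of run_audio_classification.py. A minimal sketch of registering it from Python, assuming the config is saved as sweep.yaml (the filename and project name are assumptions, not part of the commit):

    import yaml
    import wandb

    with open("sweep.yaml") as f:
        sweep_config = yaml.safe_load(f)

    # Register the sweep; the two freeze_base_model values yield two runs.
    sweep_id = wandb.sweep(sweep_config, project="accent-classification")

    # Since the config declares `program: run_audio_classification.py`, the
    # runs themselves are executed by the command-line agent, e.g.:
    #   wandb agent <entity>/accent-classification/<sweep_id>
    print(sweep_id)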
@@ -57,6 +57,13 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
     random_offset = randint(0, len(wav) - sample_length - 1)
     return wav[random_offset : random_offset + sample_length]
 
 
+def deterministic_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
+    """Take the first `max_length` seconds from the input audio."""
+    sample_length = int(round(sample_rate * max_length))
+    if len(wav) <= sample_length:
+        return wav
+    return wav[0:sample_length]
+
+
 ACCENT_MAPPING = {
     "British": "English",
@@ -603,28 +610,30 @@ def main():
             desc="Filtering by labels",
         )
 
+        def prepare_dataset(batch):
+            batch["length"] = len(batch["audio"]["array"])
+            batch["labels"] = preprocess_labels(batch["labels"])
+            return batch
+
+        raw_datasets = raw_datasets.map(
+            prepare_dataset,
+            num_proc=data_args.preprocessing_num_workers,
+            desc="Computing audio length",
+        )
+
         # filter training data with inputs < min_input_length
-        max_input_length = data_args.max_length_seconds * sampling_rate
         min_input_length = data_args.min_length_seconds * sampling_rate
 
-        def is_audio_valid(audio):
-            return max_input_length > len(audio["array"]) > min_input_length
+        def is_audio_valid(input_length):
+            return input_length > min_input_length
 
         raw_datasets = raw_datasets.filter(
             is_audio_valid,
-            input_columns=["audio"],
+            input_columns=["length"],
             num_proc=data_args.preprocessing_num_workers,
             desc="Filtering by audio length",
         )
 
-        # Prepare label mappings
-        raw_datasets = raw_datasets.map(
-            lambda label: {"labels": preprocess_labels(label)},
-            input_columns=["labels"],
-            num_proc=data_args.preprocessing_num_workers,
-            desc="Pre-processing labels",
-        )
-
         # Print a summary of the labels to stdout (helps identify low-label classes that could be filtered)
         # sort by freq
         count_labels_dict = Counter(raw_datasets["train"]["labels"])
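
prepare_dataset now caches an integer `length` column and normalises the labels in a single map pass, and the filter then reads that cached column instead of decoding every audio file a second time; the upper length bound disappears here because, as the next hunk shows, train_transforms now truncates long clips itself. A minimal standalone sketch of the cache-then-filter pattern (toy data, not from the repo):

    from datasets import Dataset

    # Toy stand-in for the audio dataset: three clips of 8, 32 and 64 samples.
    ds = Dataset.from_dict({"audio": [{"array": [0.0] * n} for n in (8, 32, 64)]})

    # Pass 1: cache a cheap length column so later steps never touch the audio again.
    ds = ds.map(lambda example: {"length": len(example["audio"]["array"])})

    # Pass 2: filter on the cached column only.
    min_input_length = 16
    ds = ds.filter(lambda length: length > min_input_length, input_columns=["length"])
    print(ds["length"])  # [32, 64]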
@@ -664,9 +673,14 @@ def main():
     def train_transforms(batch):
         """Apply train_transforms across a batch."""
-        audios = [audio["array"] for audio in batch["audio"]]
+        subsampled_wavs = []
+        for audio in batch["audio"]:
+            wav = deterministic_subsample(
+                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
+            )
+            subsampled_wavs.append(wav)
         inputs = feature_extractor(
-            audios, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
+            subsampled_wavs, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
         )
         output_batch = {
             model_input_name: inputs.get(model_input_name),
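
With the subsampling moved into train_transforms, every training example is truncated to max_length_seconds before feature extraction, which bounds the padded batch length regardless of how long the raw recordings are. An illustrative sketch of the effect with synthetic audio (the real script leaves padding to its collator; the padding arguments here are assumptions):

    import numpy as np
    from transformers import AutoFeatureExtractor

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-126")
    sr = feature_extractor.sampling_rate  # 16 kHz for MMS

    # Clips of 5 s, 12 s and 45 s: untruncated, the 45 s clip would dictate padding.
    raw = [np.random.randn(sr * n) for n in (5, 12, 45)]
    capped = [deterministic_subsample(wav, max_length=20.0, sample_rate=sr) for wav in raw]

    inputs = feature_extractor(
        capped, sampling_rate=sr, padding="longest", return_attention_mask=True, return_tensors="np"
    )
    print(inputs["input_values"].shape)  # (3, 320000): padded to the 20 s cap, not to 45 s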