Commit 0d5d9970 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

concat classification

parent b7b225a4
...@@ -2,36 +2,37 @@ ...@@ -2,36 +2,37 @@
python run_audio_classification.py \ python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \ --model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \ --train_dataset_name "stable-speech/concatenated-accent-dataset" \
--train_dataset_config_name "default+en_accented+default" \ --train_dataset_config_name "default" \
--train_split_name "train+test+validation" \ --train_split_name "train" \
--train_label_column_name "accent+accent+accent" \ --train_label_column_name "labels" \
--eval_dataset_name "sanchit-gandhi/edacc" \ --eval_dataset_name "stable-speech/concatenated-accent-dataset" \
--eval_dataset_config_name "default" \ --eval_dataset_config_name "default" \
--eval_split_name "test" \ --eval_split_name "test" \
--eval_label_column_name "accent" \ --eval_label_column_name "labels" \
--output_dir "./" \ --output_dir "./" \
--do_train \ --do_train \
--do_eval \ --do_eval \
--overwrite_output_dir \ --overwrite_output_dir \
--remove_unused_columns False \ --remove_unused_columns False \
--fp16 \ --fp16 \
--fp16_full_eval \
--learning_rate 1e-4 \ --learning_rate 1e-4 \
--max_length_seconds 20 \ --max_length_seconds 20 \
--attention_mask False \ --min_length_seconds 5 \
--warmup_ratio 0.1 \ --attention_mask \
--num_train_epochs 5 \ --warmup_steps 100 \
--max_steps 1000 \
--per_device_train_batch_size 32 \ --per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \ --per_device_eval_batch_size 32 \
--preprocessing_num_workers 16 \ --preprocessing_num_workers 4 \
--dataloader_num_workers 4 \ --dataloader_num_workers 4 \
--logging_strategy "steps" \ --logging_strategy "steps" \
--logging_steps 10 \ --logging_steps 10 \
--evaluation_strategy "epoch" \ --evaluation_strategy "steps" \
--save_strategy "epoch" \ --eval_steps 500 \
--load_best_model_at_end True \ --save_strategy "steps" \
--metric_for_best_model "accuracy" \ --save_steps 1000 \
--save_total_limit 3 \ --freeze_base_model False \
--freeze_base_model \ --push_to_hub False \
--push_to_hub \
--trust_remote_code --trust_remote_code
...@@ -17,25 +17,23 @@ metric: ...@@ -17,25 +17,23 @@ metric:
name: eval/accuracy name: eval/accuracy
parameters: parameters:
model_name_or_path: model_name_or_path:
values: value: facebook/mms-lid-126
- facebook/mms-lid-126
- openai/whisper-large-v3
train_dataset_name: train_dataset_name:
value: sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc value: stable-speech/concatenated-accent-dataset
train_dataset_config_name: train_dataset_config_name:
value: default+en_accented+default value: default
train_split_name: train_split_name:
value: train+test+validation value: train
train_label_column_name: train_label_column_name:
value: accent+accent+accent value: labels
eval_dataset_name: eval_dataset_name:
value: sanchit-gandhi/edacc value: stable-speech/concatenated-accent-dataset
eval_dataset_config_name: eval_dataset_config_name:
value: default value: default
eval_split_name: eval_split_name:
value: test value: test
eval_label_column_name: eval_label_column_name:
value: accent value: labels
output_dir: output_dir:
value: ./ value: ./
remove_unused_columns: remove_unused_columns:
...@@ -45,13 +43,13 @@ parameters: ...@@ -45,13 +43,13 @@ parameters:
lr_scheduler_type: lr_scheduler_type:
value: constant_with_warmup value: constant_with_warmup
max_length_seconds: max_length_seconds:
value: 10 # give some data diversity for longer audio samples value: 20 # give some data diversity for longer audio samples
min_length_seconds: min_length_seconds:
value: 5 value: 7
attention_mask: attention_mask:
value: false value: true
warmup_steps: warmup_steps:
value: 50 value: 100
max_steps: max_steps:
value: 2000 value: 2000
per_device_train_batch_size: per_device_train_batch_size:
...@@ -59,7 +57,7 @@ parameters: ...@@ -59,7 +57,7 @@ parameters:
per_device_eval_batch_size: per_device_eval_batch_size:
value: 16 value: 16
preprocessing_num_workers: preprocessing_num_workers:
value: 16 value: 4
dataloader_num_workers: dataloader_num_workers:
value: 4 value: 4
logging_strategy: logging_strategy:
...@@ -69,7 +67,7 @@ parameters: ...@@ -69,7 +67,7 @@ parameters:
evaluation_strategy: evaluation_strategy:
value: steps value: steps
eval_steps: eval_steps:
value: 2000 value: 1000
save_strategy: save_strategy:
value: steps value: steps
save_steps: save_steps:
...@@ -77,7 +75,11 @@ parameters: ...@@ -77,7 +75,11 @@ parameters:
metric_for_best_model: metric_for_best_model:
value: accuracy value: accuracy
freeze_base_model: freeze_base_model:
value: false values:
- false
- true
group_by_length:
value: false # TODO(SG): batch by length
push_to_hub: push_to_hub:
value: false value: false
program: run_audio_classification.py program: run_audio_classification.py
......
...@@ -57,6 +57,13 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600 ...@@ -57,6 +57,13 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600
random_offset = randint(0, len(wav) - sample_length - 1) random_offset = randint(0, len(wav) - sample_length - 1)
return wav[random_offset : random_offset + sample_length] return wav[random_offset : random_offset + sample_length]
def deterministic_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
    """Truncate `wav` to at most the first `max_length` seconds of audio.

    Unlike a random crop, this always keeps the beginning of the clip, so the
    result is deterministic across epochs. Inputs already shorter than the
    target length are returned unchanged.
    """
    target_samples = int(round(max_length * sample_rate))
    return wav if len(wav) <= target_samples else wav[:target_samples]
ACCENT_MAPPING = { ACCENT_MAPPING = {
"British": "English", "British": "English",
...@@ -603,28 +610,30 @@ def main(): ...@@ -603,28 +610,30 @@ def main():
desc="Filtering by labels", desc="Filtering by labels",
) )
def prepare_dataset(batch):
    """Annotate a single example with its raw audio sample count and a normalised label.

    The `length` column is used downstream to filter out clips that are too
    short; `labels` is rewritten through `preprocess_labels`.
    """
    raw_audio = batch["audio"]["array"]
    batch["length"] = len(raw_audio)
    batch["labels"] = preprocess_labels(batch["labels"])
    return batch
raw_datasets = raw_datasets.map(
prepare_dataset,
num_proc=data_args.preprocessing_num_workers,
desc="Computing audio length",
)
# filter training data with inputs < min_input_length # filter training data with inputs < min_input_length
max_input_length = data_args.max_length_seconds * sampling_rate
min_input_length = data_args.min_length_seconds * sampling_rate min_input_length = data_args.min_length_seconds * sampling_rate
def is_audio_valid(audio): def is_audio_valid(input_length):
return max_input_length > len(audio["array"]) > min_input_length return input_length > min_input_length
raw_datasets = raw_datasets.filter( raw_datasets = raw_datasets.filter(
is_audio_valid, is_audio_valid,
input_columns=["audio"], input_columns=["length"],
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
desc="Filtering by audio length", desc="Filtering by audio length",
) )
# Prepare label mappings
raw_datasets = raw_datasets.map(
lambda label: {"labels": preprocess_labels(label)},
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Pre-processing labels",
)
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered) # Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
# sort by freq # sort by freq
count_labels_dict = Counter(raw_datasets["train"]["labels"]) count_labels_dict = Counter(raw_datasets["train"]["labels"])
...@@ -664,9 +673,14 @@ def main(): ...@@ -664,9 +673,14 @@ def main():
def train_transforms(batch): def train_transforms(batch):
"""Apply train_transforms across a batch.""" """Apply train_transforms across a batch."""
audios = [audio["array"] for audio in batch["audio"]] subsampled_wavs = []
for audio in batch["audio"]:
wav = deterministic_subsample(
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
)
subsampled_wavs.append(wav)
inputs = feature_extractor( inputs = feature_extractor(
audios, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate subsampled_wavs, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
) )
output_batch = { output_batch = {
model_input_name: inputs.get(model_input_name), model_input_name: inputs.get(model_input_name),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment