Commit 9518430f authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

use normalised dataset

parent 92ad4bd8
......@@ -2,11 +2,11 @@
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "stable-speech/concatenated-accent-dataset" \
--train_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
--train_dataset_config_name "default" \
--train_split_name "train" \
--train_label_column_name "labels" \
--eval_dataset_name "stable-speech/concatenated-accent-dataset" \
--eval_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "labels" \
......@@ -22,7 +22,7 @@ python run_audio_classification.py \
--min_length_seconds 5 \
--attention_mask \
--warmup_steps 100 \
--max_steps 1000 \
--max_steps 2000 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--preprocessing_num_workers 4 \
......@@ -31,8 +31,8 @@ python run_audio_classification.py \
--logging_steps 10 \
--evaluation_strategy "steps" \
--eval_steps 500 \
--save_strategy "steps" \
--save_steps 1000 \
--freeze_base_model False \
--save_strategy "no" \
--save_steps 2000 \
--freeze_base_model True \
--push_to_hub False \
--trust_remote_code
......@@ -71,6 +71,21 @@ ACCENT_MAPPING = {
"Northern irish": "Irish",
"New zealand": "Australian",
"Pakistani": "Indian",
"Mainstream us english": "American",
"Southern british english": "English",
"Indian english": "Indian",
"Scottish english": "Scottish",
"Don't know": "Unknown",
"Nigerian english": "Nigerian",
"Kenyan english": "Kenyan",
"Ghanain english": "Ghanain",
"Jamaican english": "Jamaican",
"Indonesian english": "Indonesian",
"South african english": "South african",
"Irish english": "Irish",
"Latin": "Latin American",
"European": "Unknown", # Too general
"Eastern european": "Eastern european", # TODO(SG): keep for now, but maybe remove later
}
......@@ -599,17 +614,6 @@ def main():
sampling_rate = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]
# filter training data with non-valid labels
def is_label_valid(label):
return label != "Unknown"
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by labels",
)
def prepare_dataset(batch):
batch["length"] = len(batch["audio"]["array"])
batch["labels"] = preprocess_labels(batch["labels"])
......@@ -634,6 +638,17 @@ def main():
desc="Filtering by audio length",
)
# filter training data with non-valid labels
def is_label_valid(label):
return label != "Unknown"
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by labels",
)
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
# sort by freq
count_labels_dict = Counter(raw_datasets["train"]["labels"])
......@@ -650,11 +665,11 @@ def main():
if freq < data_args.filter_threshold:
labels_to_remove.append(lab)
# filter training data with label freq below threshold
def is_label_valid(label):
return label not in labels_to_remove
if len(labels_to_remove):
# filter training data with label freq below threshold
def is_label_valid(label):
return label not in labels_to_remove
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment