Commit 9518430f authored by sanchit-gandhi

use normalised dataset

parent 92ad4bd8
@@ -2,11 +2,11 @@
 python run_audio_classification.py \
     --model_name_or_path "facebook/mms-lid-126" \
-    --train_dataset_name "stable-speech/concatenated-accent-dataset" \
+    --train_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
     --train_dataset_config_name "default" \
     --train_split_name "train" \
     --train_label_column_name "labels" \
-    --eval_dataset_name "stable-speech/concatenated-accent-dataset" \
+    --eval_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
     --eval_dataset_config_name "default" \
     --eval_split_name "test" \
     --eval_label_column_name "labels" \
@@ -22,7 +22,7 @@ python run_audio_classification.py \
     --min_length_seconds 5 \
     --attention_mask \
     --warmup_steps 100 \
-    --max_steps 1000 \
+    --max_steps 2000 \
     --per_device_train_batch_size 32 \
     --per_device_eval_batch_size 32 \
     --preprocessing_num_workers 4 \
@@ -31,8 +31,8 @@ python run_audio_classification.py \
     --logging_steps 10 \
     --evaluation_strategy "steps" \
     --eval_steps 500 \
-    --save_strategy "steps" \
-    --save_steps 1000 \
-    --freeze_base_model False \
+    --save_strategy "no" \
+    --save_steps 2000 \
+    --freeze_base_model True \
     --push_to_hub False \
     --trust_remote_code
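
Note: flipping --freeze_base_model to True means only the classification head is updated over the 2000 training steps. A minimal sketch of the effect, assuming the flag ends up calling the model's freeze_base_model() method (the flag's wiring inside run_audio_classification.py is not part of this diff):

# Minimal sketch of what `--freeze_base_model True` presumably toggles.
from transformers import AutoModelForAudioClassification

model = AutoModelForAudioClassification.from_pretrained("facebook/mms-lid-126")

# Wav2Vec2-style classification models expose freeze_base_model(), which sets
# requires_grad = False on every parameter of the underlying encoder, leaving
# only the projector and classifier head trainable.
model.freeze_base_model()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable} / {total} parameters")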
@@ -71,6 +71,21 @@ ACCENT_MAPPING = {
     "Northern irish": "Irish",
     "New zealand": "Australian",
     "Pakistani": "Indian",
+    "Mainstream us english": "American",
+    "Southern british english": "English",
+    "Indian english": "Indian",
+    "Scottish english": "Scottish",
+    "Don't know": "Unknown",
+    "Nigerian english": "Nigerian",
+    "Kenyan english": "Kenyan",
+    "Ghanain english": "Ghanain",
+    "Jamaican english": "Jamaican",
+    "Indonesian english": "Indonesian",
+    "South african english": "South african",
+    "Irish english": "Irish",
+    "Latin": "Latin American",
+    "European": "Unknown",  # Too general
+    "Eastern european": "Eastern european",  # TODO(SG): keep for now, but maybe remove later
 }
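
Note: the fifteen new keys fold the normalised dataset's verbose label spellings (e.g. "Mainstream us english") into the existing coarse accent classes. A hypothetical sketch of how the lookup is presumably applied; preprocess_labels itself is outside this diff, and the capitalise-then-lookup step is an assumption inferred from the key format:

# Hypothetical sketch; preprocess_labels is defined elsewhere in the script.
ACCENT_MAPPING = {  # excerpt of the mapping extended above
    "Mainstream us english": "American",
    "Don't know": "Unknown",
}

def preprocess_labels(label: str) -> str:
    label = label.strip().capitalize()       # "mainstream US english" -> "Mainstream us english"
    return ACCENT_MAPPING.get(label, label)  # unmapped labels pass through unchanged

print(preprocess_labels("mainstream us english"))  # -> "American"
print(preprocess_labels("don't know"))             # -> "Unknown"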
@@ -599,17 +614,6 @@ def main():
     sampling_rate = feature_extractor.sampling_rate
     model_input_name = feature_extractor.model_input_names[0]
 
-    # filter training data with non-valid labels
-    def is_label_valid(label):
-        return label != "Unknown"
-
-    raw_datasets = raw_datasets.filter(
-        is_label_valid,
-        input_columns=["labels"],
-        num_proc=data_args.preprocessing_num_workers,
-        desc="Filtering by labels",
-    )
-
     def prepare_dataset(batch):
         batch["length"] = len(batch["audio"]["array"])
         batch["labels"] = preprocess_labels(batch["labels"])
@@ -634,6 +638,17 @@ def main():
         desc="Filtering by audio length",
     )
 
+    # filter training data with non-valid labels
+    def is_label_valid(label):
+        return label != "Unknown"
+
+    raw_datasets = raw_datasets.filter(
+        is_label_valid,
+        input_columns=["labels"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Filtering by labels",
+    )
+
     # Print a summary of the labels to the stdout (helps identify low-label classes that could be filtered)
     # sort by freq
     count_labels_dict = Counter(raw_datasets["train"]["labels"])
@@ -650,11 +665,11 @@ def main():
         if freq < data_args.filter_threshold:
             labels_to_remove.append(lab)
 
-    # filter training data with label freq below threshold
-    def is_label_valid(label):
-        return label not in labels_to_remove
-
     if len(labels_to_remove):
+        # filter training data with label freq below threshold
+        def is_label_valid(label):
+            return label not in labels_to_remove
+
         raw_datasets = raw_datasets.filter(
             is_label_valid,
             input_columns=["labels"],
...
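
Note: the net effect of the reordered blocks is a three-stage filter: the audio-length filter runs first, then the "Unknown"-label filter, and the frequency-threshold filter is now only built when labels_to_remove is non-empty. A self-contained sketch of the label-filtering stages on a toy 🤗 Datasets object; the threshold value here is illustrative, the script reads it from data_args.filter_threshold:

from collections import Counter
from datasets import Dataset

# Toy stand-in for raw_datasets["train"]; the real data carries audio arrays too.
train = Dataset.from_dict(
    {"labels": ["American", "American", "Indian", "Unknown", "Irish"]}
)

# 1) drop "Unknown" labels (now applied after the audio-length filter)
train = train.filter(lambda label: label != "Unknown", input_columns=["labels"])

# 2) summarise label frequencies, sorted by count
count_labels_dict = Counter(train["labels"])
for lab, freq in sorted(count_labels_dict.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{lab}: {freq}")

# 3) drop labels rarer than the threshold (2 here, for illustration)
filter_threshold = 2
labels_to_remove = [lab for lab, freq in count_labels_dict.items() if freq < filter_threshold]

if len(labels_to_remove):
    train = train.filter(lambda label: label not in labels_to_remove, input_columns=["labels"])

print(train["labels"])  # ['American', 'American']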