Commit 294c162d authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

filter by freq

parent b03a236d
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import os import os
import re import re
import sys import sys
from collections import Counter
from dataclasses import dataclass, field from dataclasses import dataclass, field
from random import randint from random import randint
from typing import List, Optional, Union from typing import List, Optional, Union
...@@ -183,6 +184,10 @@ class DataTrainingArguments: ...@@ -183,6 +184,10 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "The number of processes to use for the preprocessing."}, metadata={"help": "The number of processes to use for the preprocessing."},
) )
filter_threshold: Optional[float] = field(
default=1.0,
metadata={"help": "Filter labels that occur less than `filter_threshold` percent in the training/eval data."},
)
@dataclass @dataclass
...@@ -571,6 +576,35 @@ def main(): ...@@ -571,6 +576,35 @@ def main():
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
desc="Pre-processing labels", desc="Pre-processing labels",
) )
# Print a summary of the labels to stdout (helps identify low-frequency classes that could be filtered)
# sort by freq
count_labels_dict = Counter(raw_datasets["train"]["labels"])
count_labels_dict = sorted(count_labels_dict.items(), key=lambda item: (-item[1], item[0]))
labels, frequencies = zip(*count_labels_dict)
total_labels = sum(frequencies)
labels_to_remove = []
logger.info(f"{'Accent':<15} {'Perc.':<5}")
logger.info("-" * 20)
for lab, freq in zip(labels, frequencies):
freq = 100 * freq / total_labels
logger.info(f"{lab:<15} {freq:<5}")
if freq < data_args.filter_threshold:
labels_to_remove.append(lab)
# filter out examples whose label frequency (measured on the training split) is below the threshold
def is_label_valid(label):
    """Return True when `label` is not listed in `labels_to_remove`."""
    # `labels_to_remove` is read from the enclosing scope.
    flagged_for_removal = label in labels_to_remove
    return not flagged_for_removal
if len(labels_to_remove):
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering low freq labels",
)
# We'll include these in the model's config to get human readable labels in the Inference API. # We'll include these in the model's config to get human readable labels in the Inference API.
set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"])) set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"]))
label2id, id2label = {}, {} label2id, id2label = {}, {}
...@@ -578,18 +612,6 @@ def main(): ...@@ -578,18 +612,6 @@ def main():
label2id[label] = str(i) label2id[label] = str(i)
id2label[str(i)] = label id2label[str(i)] = label
train_labels = raw_datasets["train"]["labels"]
num_labels = {key: 0 for key in set(train_labels)}
for label in train_labels:
num_labels[label] += 1
# Print a summary of the labels to stdout (helps identify low-label classes that could be filtered)
num_labels = sorted(num_labels.items(), key=lambda x: (-x[1], x[0]))
logger.info(f"{'Language':<15} {'Count':<5}")
logger.info("-" * 20)
for language, count in num_labels:
logger.info(f"{language:<15} {count:<5}")
def train_transforms(batch): def train_transforms(batch):
"""Apply train_transforms across a batch.""" """Apply train_transforms across a batch."""
subsampled_wavs = [] subsampled_wavs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment