Commit 294c162d authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

filter by freq

parent b03a236d
......@@ -18,6 +18,7 @@ import logging
import os
import re
import sys
from collections import Counter
from dataclasses import dataclass, field
from random import randint
from typing import List, Optional, Union
......@@ -183,6 +184,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
filter_threshold: Optional[float] = field(
default=1.0,
metadata={"help": "Filter labels that occur less than `filter_threshold` percent in the training/eval data."},
)
@dataclass
......@@ -571,6 +576,35 @@ def main():
num_proc=data_args.preprocessing_num_workers,
desc="Pre-processing labels",
)
# Print a summary of the labels to stdout (helps identify low-frequency classes that could be filtered)
# sort by freq
count_labels_dict = Counter(raw_datasets["train"]["labels"])
count_labels_dict = sorted(count_labels_dict.items(), key=lambda item: (-item[1], item[0]))
labels, frequencies = zip(*count_labels_dict)
total_labels = sum(frequencies)
labels_to_remove = []
logger.info(f"{'Accent':<15} {'Perc.':<5}")
logger.info("-" * 20)
for lab, freq in zip(labels, frequencies):
freq = 100 * freq / total_labels
logger.info(f"{lab:<15} {freq:<5}")
if freq < data_args.filter_threshold:
labels_to_remove.append(lab)
# Filter out examples whose label frequency in the training split is below `filter_threshold` percent
# Predicate handed to `raw_datasets.filter` with `input_columns=["labels"]`, so it
# receives one label value per example. Returns True (keep the example) unless the
# label was collected in `labels_to_remove` by the frequency loop above.
def is_label_valid(label):
return label not in labels_to_remove
if len(labels_to_remove):
raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering low freq labels",
)
# We'll include these in the model's config to get human readable labels in the Inference API.
set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"]))
label2id, id2label = {}, {}
......@@ -578,18 +612,6 @@ def main():
label2id[label] = str(i)
id2label[str(i)] = label
train_labels = raw_datasets["train"]["labels"]
num_labels = {key: 0 for key in set(train_labels)}
for label in train_labels:
num_labels[label] += 1
# Print a summary of the labels to stdout (helps identify low-frequency classes that could be filtered)
num_labels = sorted(num_labels.items(), key=lambda x: (-x[1], x[0]))
logger.info(f"{'Language':<15} {'Count':<5}")
logger.info("-" * 20)
for language, count in num_labels:
logger.info(f"{language:<15} {count:<5}")
def train_transforms(batch):
"""Apply train_transforms across a batch."""
subsampled_wavs = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment