Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
294c162d
Commit
294c162d
authored
Feb 16, 2024
by
sanchit-gandhi
Browse files
filter by freq
parent
b03a236d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
12 deletions
+34
-12
run_audio_classification.py
run_audio_classification.py
+34
-12
No files found.
run_audio_classification.py
View file @
294c162d
...
...
@@ -18,6 +18,7 @@ import logging
import
os
import
re
import
sys
from
collections
import
Counter
from
dataclasses
import
dataclass
,
field
from
random
import
randint
from
typing
import
List
,
Optional
,
Union
...
...
@@ -183,6 +184,10 @@ class DataTrainingArguments:
default
=
None
,
metadata
=
{
"help"
:
"The number of processes to use for the preprocessing."
},
)
filter_threshold
:
Optional
[
float
]
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"Filter labels that occur less than `filter_threshold` percent in the training/eval data."
},
)
@
dataclass
...
...
@@ -571,6 +576,35 @@ def main():
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Pre-processing labels"
,
)
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
# sort by freq
count_labels_dict
=
Counter
(
raw_datasets
[
"train"
][
"labels"
])
count_labels_dict
=
sorted
(
count_labels_dict
.
items
(),
key
=
lambda
item
:
(
-
item
[
1
],
item
[
0
]))
labels
,
frequencies
=
zip
(
*
count_labels_dict
)
total_labels
=
sum
(
frequencies
)
labels_to_remove
=
[]
logger
.
info
(
f
"
{
'Accent'
:
<
15
}
{
'Perc.'
:
<
5
}
"
)
logger
.
info
(
"-"
*
20
)
for
lab
,
freq
in
zip
(
labels
,
frequencies
):
freq
=
100
*
freq
/
total_labels
logger
.
info
(
f
"
{
lab
:
<
15
}
{
freq
:
<
5
}
"
)
if
freq
<
data_args
.
filter_threshold
:
labels_to_remove
.
append
(
lab
)
# filter training data with label freq below threshold
def
is_label_valid
(
label
):
return
label
not
in
labels_to_remove
if
len
(
labels_to_remove
):
raw_datasets
=
raw_datasets
.
filter
(
is_label_valid
,
input_columns
=
[
"labels"
],
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Filtering low freq labels"
,
)
# We'll include these in the model's config to get human readable labels in the Inference API.
set_labels
=
set
(
raw_datasets
[
"train"
][
"labels"
]).
union
(
set
(
raw_datasets
[
"eval"
][
"labels"
]))
label2id
,
id2label
=
{},
{}
...
...
@@ -578,18 +612,6 @@ def main():
label2id
[
label
]
=
str
(
i
)
id2label
[
str
(
i
)]
=
label
train_labels
=
raw_datasets
[
"train"
][
"labels"
]
num_labels
=
{
key
:
0
for
key
in
set
(
train_labels
)}
for
label
in
train_labels
:
num_labels
[
label
]
+=
1
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
num_labels
=
sorted
(
num_labels
.
items
(),
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
logger
.
info
(
f
"
{
'Language'
:
<
15
}
{
'Count'
:
<
5
}
"
)
logger
.
info
(
"-"
*
20
)
for
language
,
count
in
num_labels
:
logger
.
info
(
f
"
{
language
:
<
15
}
{
count
:
<
5
}
"
)
def
train_transforms
(
batch
):
"""Apply train_transforms across a batch."""
subsampled_wavs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment