Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
9518430f
Commit
9518430f
authored
Feb 26, 2024
by
sanchit-gandhi
Browse files
use normalised dataset
parent
92ad4bd8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
21 deletions
+36
-21
audio_classification_scripts/run_mms_lid.sh
audio_classification_scripts/run_mms_lid.sh
+6
-6
run_audio_classification.py
run_audio_classification.py
+30
-15
No files found.
audio_classification_scripts/run_mms_lid.sh
View file @
9518430f
...
...
@@ -2,11 +2,11 @@
python run_audio_classification.py
\
--model_name_or_path
"facebook/mms-lid-126"
\
--train_dataset_name
"stable-speech/concatenated-accent-dataset"
\
--train_dataset_name
"stable-speech/concatenated-
normalized-
accent-dataset"
\
--train_dataset_config_name
"default"
\
--train_split_name
"train"
\
--train_label_column_name
"labels"
\
--eval_dataset_name
"stable-speech/concatenated-accent-dataset"
\
--eval_dataset_name
"stable-speech/concatenated-
normalized-
accent-dataset"
\
--eval_dataset_config_name
"default"
\
--eval_split_name
"test"
\
--eval_label_column_name
"labels"
\
...
...
@@ -22,7 +22,7 @@ python run_audio_classification.py \
--min_length_seconds
5
\
--attention_mask
\
--warmup_steps
100
\
--max_steps
1
000
\
--max_steps
2
000
\
--per_device_train_batch_size
32
\
--per_device_eval_batch_size
32
\
--preprocessing_num_workers
4
\
...
...
@@ -31,8 +31,8 @@ python run_audio_classification.py \
--logging_steps
10
\
--evaluation_strategy
"steps"
\
--eval_steps
500
\
--save_strategy
"
steps
"
\
--save_steps
1
000
\
--freeze_base_model
Fals
e
\
--save_strategy
"
no
"
\
--save_steps
2
000
\
--freeze_base_model
Tru
e
\
--push_to_hub
False
\
--trust_remote_code
run_audio_classification.py
View file @
9518430f
...
...
@@ -71,6 +71,21 @@ ACCENT_MAPPING = {
"Northern irish"
:
"Irish"
,
"New zealand"
:
"Australian"
,
"Pakistani"
:
"Indian"
,
"Mainstream us english"
:
"American"
,
"Southern british english"
:
"English"
,
"Indian english"
:
"Indian"
,
"Scottish english"
:
"Scottish"
,
"Don't know"
:
"Unknown"
,
"Nigerian english"
:
"Nigerian"
,
"Kenyan english"
:
"Kenyan"
,
"Ghanain english"
:
"Ghanain"
,
"Jamaican english"
:
"Jamaican"
,
"Indonesian english"
:
"Indonesian"
,
"South african english"
:
"South african"
,
"Irish english"
:
"Irish"
,
"Latin"
:
"Latin American"
,
"European"
:
"Unknown"
,
# Too general
"Eastern european"
:
"Eastern european"
,
# TODO(SG): keep for now, but maybe remove later
}
...
...
@@ -599,17 +614,6 @@ def main():
sampling_rate
=
feature_extractor
.
sampling_rate
model_input_name
=
feature_extractor
.
model_input_names
[
0
]
# filter training data with non-valid labels
def
is_label_valid
(
label
):
return
label
!=
"Unknown"
raw_datasets
=
raw_datasets
.
filter
(
is_label_valid
,
input_columns
=
[
"labels"
],
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Filtering by labels"
,
)
def
prepare_dataset
(
batch
):
batch
[
"length"
]
=
len
(
batch
[
"audio"
][
"array"
])
batch
[
"labels"
]
=
preprocess_labels
(
batch
[
"labels"
])
...
...
@@ -634,6 +638,17 @@ def main():
desc
=
"Filtering by audio length"
,
)
# filter training data with non-valid labels
def
is_label_valid
(
label
):
return
label
!=
"Unknown"
raw_datasets
=
raw_datasets
.
filter
(
is_label_valid
,
input_columns
=
[
"labels"
],
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Filtering by labels"
,
)
# Print a summary of the labels to stdout (helps identify low-label classes that could be filtered)
# sort by freq
count_labels_dict
=
Counter
(
raw_datasets
[
"train"
][
"labels"
])
...
...
@@ -650,11 +665,11 @@ def main():
if
freq
<
data_args
.
filter_threshold
:
labels_to_remove
.
append
(
lab
)
# filter training data with label freq below threshold
def
is_label_valid
(
label
):
return
label
not
in
labels_to_remove
if
len
(
labels_to_remove
):
# filter training data with label freq below threshold
def
is_label_valid
(
label
):
return
label
not
in
labels_to_remove
raw_datasets
=
raw_datasets
.
filter
(
is_label_valid
,
input_columns
=
[
"labels"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment