Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
0d5d9970
Commit
0d5d9970
authored
Feb 21, 2024
by
sanchit-gandhi
Browse files
concat classification
parent
b7b225a4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
47 deletions
+64
-47
audio_classification_scripts/run_mms_lid.sh
audio_classification_scripts/run_mms_lid.sh
+18
-17
audio_classification_scripts/run_sweep.yaml
audio_classification_scripts/run_sweep.yaml
+18
-16
run_audio_classification.py
run_audio_classification.py
+28
-14
No files found.
audio_classification_scripts/run_mms_lid.sh
View file @
0d5d9970
...
@@ -2,36 +2,37 @@
...
@@ -2,36 +2,37 @@
python run_audio_classification.py
\
python run_audio_classification.py
\
--model_name_or_path
"facebook/mms-lid-126"
\
--model_name_or_path
"facebook/mms-lid-126"
\
--train_dataset_name
"s
anchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/
edacc"
\
--train_dataset_name
"s
table-speech/concatenat
ed
-
acc
ent-dataset
"
\
--train_dataset_config_name
"default
+en_accented+default
"
\
--train_dataset_config_name
"default"
\
--train_split_name
"train
+test+validation
"
\
--train_split_name
"train"
\
--train_label_column_name
"
accent+accent+accent
"
\
--train_label_column_name
"
labels
"
\
--eval_dataset_name
"s
anchit-gandhi/
edacc"
\
--eval_dataset_name
"s
table-speech/concatenat
ed
-
acc
ent-dataset
"
\
--eval_dataset_config_name
"default"
\
--eval_dataset_config_name
"default"
\
--eval_split_name
"test"
\
--eval_split_name
"test"
\
--eval_label_column_name
"
accent
"
\
--eval_label_column_name
"
labels
"
\
--output_dir
"./"
\
--output_dir
"./"
\
--do_train
\
--do_train
\
--do_eval
\
--do_eval
\
--overwrite_output_dir
\
--overwrite_output_dir
\
--remove_unused_columns
False
\
--remove_unused_columns
False
\
--fp16
\
--fp16
\
--fp16_full_eval
\
--learning_rate
1e-4
\
--learning_rate
1e-4
\
--max_length_seconds
20
\
--max_length_seconds
20
\
--attention_mask
False
\
--min_length_seconds
5
\
--warmup_ratio
0.1
\
--attention_mask
\
--num_train_epochs
5
\
--warmup_steps
100
\
--max_steps
1000
\
--per_device_train_batch_size
32
\
--per_device_train_batch_size
32
\
--per_device_eval_batch_size
32
\
--per_device_eval_batch_size
32
\
--preprocessing_num_workers
16
\
--preprocessing_num_workers
4
\
--dataloader_num_workers
4
\
--dataloader_num_workers
4
\
--logging_strategy
"steps"
\
--logging_strategy
"steps"
\
--logging_steps
10
\
--logging_steps
10
\
--evaluation_strategy
"epoch"
\
--evaluation_strategy
"steps"
\
--save_strategy
"epoch"
\
--eval_steps
500
\
--load_best_model_at_end
True
\
--save_strategy
"steps"
\
--metric_for_best_model
"accuracy"
\
--save_steps
1000
\
--save_total_limit
3
\
--freeze_base_model
False
\
--freeze_base_model
\
--push_to_hub
False
\
--push_to_hub
\
--trust_remote_code
--trust_remote_code
audio_classification_scripts/run_sweep.yaml
View file @
0d5d9970
...
@@ -17,25 +17,23 @@ metric:
...
@@ -17,25 +17,23 @@ metric:
name
:
eval/accuracy
name
:
eval/accuracy
parameters
:
parameters
:
model_name_or_path
:
model_name_or_path
:
values
:
value
:
facebook/mms-lid-126
-
facebook/mms-lid-126
-
openai/whisper-large-v3
train_dataset_name
:
train_dataset_name
:
value
:
s
anchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/
edacc
value
:
s
table-speech/concatenat
ed
-
acc
ent-dataset
train_dataset_config_name
:
train_dataset_config_name
:
value
:
default
+en_accented+default
value
:
default
train_split_name
:
train_split_name
:
value
:
train
+test+validation
value
:
train
train_label_column_name
:
train_label_column_name
:
value
:
accent+accent+accent
value
:
labels
eval_dataset_name
:
eval_dataset_name
:
value
:
s
anchit-gandhi/
edacc
value
:
s
table-speech/concatenat
ed
-
acc
ent-dataset
eval_dataset_config_name
:
eval_dataset_config_name
:
value
:
default
value
:
default
eval_split_name
:
eval_split_name
:
value
:
test
value
:
test
eval_label_column_name
:
eval_label_column_name
:
value
:
accent
value
:
labels
output_dir
:
output_dir
:
value
:
./
value
:
./
remove_unused_columns
:
remove_unused_columns
:
...
@@ -45,13 +43,13 @@ parameters:
...
@@ -45,13 +43,13 @@ parameters:
lr_scheduler_type
:
lr_scheduler_type
:
value
:
constant_with_warmup
value
:
constant_with_warmup
max_length_seconds
:
max_length_seconds
:
value
:
1
0
# give some data diversity for longer audio samples
value
:
2
0
# give some data diversity for longer audio samples
min_length_seconds
:
min_length_seconds
:
value
:
5
value
:
7
attention_mask
:
attention_mask
:
value
:
fals
e
value
:
tru
e
warmup_steps
:
warmup_steps
:
value
:
5
0
value
:
10
0
max_steps
:
max_steps
:
value
:
2000
value
:
2000
per_device_train_batch_size
:
per_device_train_batch_size
:
...
@@ -59,7 +57,7 @@ parameters:
...
@@ -59,7 +57,7 @@ parameters:
per_device_eval_batch_size
:
per_device_eval_batch_size
:
value
:
16
value
:
16
preprocessing_num_workers
:
preprocessing_num_workers
:
value
:
16
value
:
4
dataloader_num_workers
:
dataloader_num_workers
:
value
:
4
value
:
4
logging_strategy
:
logging_strategy
:
...
@@ -69,7 +67,7 @@ parameters:
...
@@ -69,7 +67,7 @@ parameters:
evaluation_strategy
:
evaluation_strategy
:
value
:
steps
value
:
steps
eval_steps
:
eval_steps
:
value
:
2
000
value
:
1
000
save_strategy
:
save_strategy
:
value
:
steps
value
:
steps
save_steps
:
save_steps
:
...
@@ -77,7 +75,11 @@ parameters:
...
@@ -77,7 +75,11 @@ parameters:
metric_for_best_model
:
metric_for_best_model
:
value
:
accuracy
value
:
accuracy
freeze_base_model
:
freeze_base_model
:
value
:
false
values
:
-
false
-
true
group_by_length
:
value
:
false
# TODO(SG): batch by length
push_to_hub
:
push_to_hub
:
value
:
false
value
:
false
program
:
run_audio_classification.py
program
:
run_audio_classification.py
...
...
run_audio_classification.py
View file @
0d5d9970
...
@@ -57,6 +57,13 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600
...
@@ -57,6 +57,13 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600
random_offset
=
randint
(
0
,
len
(
wav
)
-
sample_length
-
1
)
random_offset
=
randint
(
0
,
len
(
wav
)
-
sample_length
-
1
)
return
wav
[
random_offset
:
random_offset
+
sample_length
]
return
wav
[
random_offset
:
random_offset
+
sample_length
]
def
deterministic_subsample
(
wav
:
np
.
ndarray
,
max_length
:
float
,
sample_rate
:
int
=
16000
)
->
np
.
ndarray
:
"""Take first `max_length` seconds from the input audio"""
sample_length
=
int
(
round
(
sample_rate
*
max_length
))
if
len
(
wav
)
<=
sample_length
:
return
wav
return
wav
[
0
:
sample_length
]
ACCENT_MAPPING
=
{
ACCENT_MAPPING
=
{
"British"
:
"English"
,
"British"
:
"English"
,
...
@@ -603,28 +610,30 @@ def main():
...
@@ -603,28 +610,30 @@ def main():
desc
=
"Filtering by labels"
,
desc
=
"Filtering by labels"
,
)
)
def
prepare_dataset
(
batch
):
batch
[
"length"
]
=
len
(
batch
[
"audio"
][
"array"
])
batch
[
"labels"
]
=
preprocess_labels
(
batch
[
"labels"
])
return
batch
raw_datasets
=
raw_datasets
.
map
(
prepare_dataset
,
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Computing audio length"
,
)
# filter training data with inputs < min_input_length
# filter training data with inputs < min_input_length
max_input_length
=
data_args
.
max_length_seconds
*
sampling_rate
min_input_length
=
data_args
.
min_length_seconds
*
sampling_rate
min_input_length
=
data_args
.
min_length_seconds
*
sampling_rate
def
is_audio_valid
(
audio
):
def
is_audio_valid
(
input_length
):
return
max_
input_length
>
len
(
audio
[
"array"
])
>
min_input_length
return
input_length
>
min_input_length
raw_datasets
=
raw_datasets
.
filter
(
raw_datasets
=
raw_datasets
.
filter
(
is_audio_valid
,
is_audio_valid
,
input_columns
=
[
"
audio
"
],
input_columns
=
[
"
length
"
],
num_proc
=
data_args
.
preprocessing_num_workers
,
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Filtering by audio length"
,
desc
=
"Filtering by audio length"
,
)
)
# Prepare label mappings
raw_datasets
=
raw_datasets
.
map
(
lambda
label
:
{
"labels"
:
preprocess_labels
(
label
)},
input_columns
=
[
"labels"
],
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Pre-processing labels"
,
)
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
# Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
# sort by freq
# sort by freq
count_labels_dict
=
Counter
(
raw_datasets
[
"train"
][
"labels"
])
count_labels_dict
=
Counter
(
raw_datasets
[
"train"
][
"labels"
])
...
@@ -664,9 +673,14 @@ def main():
...
@@ -664,9 +673,14 @@ def main():
def
train_transforms
(
batch
):
def
train_transforms
(
batch
):
"""Apply train_transforms across a batch."""
"""Apply train_transforms across a batch."""
audios
=
[
audio
[
"array"
]
for
audio
in
batch
[
"audio"
]]
subsampled_wavs
=
[]
for
audio
in
batch
[
"audio"
]:
wav
=
deterministic_subsample
(
audio
[
"array"
],
max_length
=
data_args
.
max_length_seconds
,
sample_rate
=
feature_extractor
.
sampling_rate
)
subsampled_wavs
.
append
(
wav
)
inputs
=
feature_extractor
(
inputs
=
feature_extractor
(
audio
s
,
return_attention_mask
=
model_args
.
attention_mask
,
sampling_rate
=
sampling_rate
subsampled_wav
s
,
return_attention_mask
=
model_args
.
attention_mask
,
sampling_rate
=
sampling_rate
)
)
output_batch
=
{
output_batch
=
{
model_input_name
:
inputs
.
get
(
model_input_name
),
model_input_name
:
inputs
.
get
(
model_input_name
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment