Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
b1fb7844
Commit
b1fb7844
authored
Feb 14, 2024
by
sanchit-gandhi
Browse files
sweep over models/freeze
parent
83953064
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
90 additions
and
6 deletions
+90
-6
audio_classification_scripts/run_mms_lid.sh
audio_classification_scripts/run_mms_lid.sh
+1
-1
audio_classification_scripts/run_sweep.yaml
audio_classification_scripts/run_sweep.yaml
+78
-0
run_audio_classification.py
run_audio_classification.py
+11
-5
No files found.
audio_classification_scripts/run_mms_lid.sh
View file @
b1fb7844
...
...
@@ -3,7 +3,7 @@
python run_audio_classification.py
\
--model_name_or_path
"facebook/mms-lid-126"
\
--train_dataset_name
"sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc"
\
--train_dataset_config_name
"
main
+en_accented+default"
\
--train_dataset_config_name
"
default
+en_accented+default"
\
--train_split_name
"train+test+validation"
\
--train_label_column_name
"accent+accent+accent"
\
--eval_dataset_name
"sanchit-gandhi/edacc"
\
...
...
audio_classification_scripts/run_sweep.yaml
0 → 100644
View file @
b1fb7844
command
:
-
python3
-
${program}
-
--load_best_model_at_end
-
--fp16
-
--do_train
-
--do_eval
-
--trust_remote_code
-
--overwrite_output_dir
-
${args}
method
:
grid
metric
:
goal
:
minimize
name
:
eval/accuracy
parameters
:
model_name_or_path
:
values
:
-
facebook/mms-lid-126
-
openai/whisper-large-v3
-
facebook/w2v-bert-2.0
train_dataset_name
:
value
:
sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc
train_dataset_config_name
:
value
:
default+en_accented+default
train_split_name
:
value
:
train+test+validation
train_label_column_name
:
value
:
accent+accent+accent
eval_dataset_name
:
value
:
sanchit-gandhi/edacc
eval_dataset_config_name
:
value
:
default
eval_split_name
:
value
:
test
eval_label_column_name
:
value
:
accent
output_dir
:
value
:
./
remove_unused_columns
:
value
:
false
learning_rate
:
value
:
1e-4
max_length_seconds
:
value
:
20
attention_mask
:
value
:
false
warmup_ratio
:
value
:
0.1
num_train_epochs
:
value
:
5
per_device_train_batch_size
:
value
:
32
per_device_eval_batch_size
:
value
:
32
preprocessing_num_workers
:
value
:
16
dataloader_num_workers
:
value
:
4
logging_strategy
:
value
:
steps
logging_steps
:
value
:
10
evaluation_strategy
:
value
:
epoch
save_strategy
:
value
:
epoch
metric_for_best_model
:
value
:
accuracy
save_total_limit
:
value
:
3
freeze_base_model
:
values
:
-
true
-
false
push_to_hub
:
value
:
false
program
:
run_audio_classification.py
project
:
mms-lid-accent-classification
\ No newline at end of file
run_audio_classification.py
View file @
b1fb7844
...
...
@@ -197,7 +197,7 @@ class ModelArguments:
default
=
None
,
metadata
=
{
"help"
:
"Name or path of preprocessor config."
}
)
freeze_feature_encoder
:
bool
=
field
(
default
=
Tru
e
,
metadata
=
{
"help"
:
"Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."
}
default
=
Fals
e
,
metadata
=
{
"help"
:
"Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."
}
)
freeze_base_model
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"Whether to freeze the base encoder of the model."
}
...
...
@@ -297,6 +297,7 @@ def load_multiple_datasets(
dataset_config_names
:
Union
[
List
,
str
],
splits
:
Optional
[
Union
[
List
,
str
]]
=
None
,
label_column_names
:
Optional
[
List
]
=
None
,
sampling_rate
:
Optional
[
int
]
=
16000
,
stopping_strategy
:
Optional
[
str
]
=
"first_exhausted"
,
dataset_samples
:
Optional
[
Union
[
List
,
np
.
array
]]
=
None
,
streaming
:
Optional
[
bool
]
=
False
,
...
...
@@ -332,6 +333,8 @@ def load_multiple_datasets(
f
" '
{
dataset_dict
[
'name'
]
}
'. Make sure to set `--audio_column_name` to"
f
" the correct audio column - one of
{
', '
.
join
(
dataset_features
)
}
."
)
# resample to specified sampling rate
dataset
=
dataset
.
cast_column
(
"audio"
,
datasets
.
features
.
Audio
(
sampling_rate
))
if
dataset_dict
[
"label_column_name"
]
not
in
dataset_features
:
raise
ValueError
(
...
...
@@ -617,16 +620,19 @@ def main():
ignore_mismatched_sizes
=
model_args
.
ignore_mismatched_sizes
,
)
# freeze the convolutional waveform encoder
# freeze the convolutional waveform encoder
for wav2vec2-style models
if
model_args
.
freeze_feature_encoder
:
model
.
freeze_feature_encoder
()
if
hasattr
(
model
,
"freeze_feature_encoder"
):
model
.
freeze_feature_encoder
()
else
:
raise
ValueError
(
"Method for freezing the feature encoder is not defined for Whisper-style models."
)
if
model_args
.
freeze_base_model
:
if
model
.
hasattr
(
"freeze_base_model"
):
if
hasattr
(
model
,
"freeze_base_model"
):
# wav2vec2-style models
model
.
freeze_base_model
()
model
.
freeze_feature_encoder
()
elif
model
.
hasattr
(
"freeze_encoder"
):
elif
hasattr
(
model
,
"freeze_encoder"
):
# whisper-style models
model
.
freeze_encoder
()
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment