chenpangpang / parler-tts · Commits · 5b5167d8

Commit 5b5167d8 authored Feb 14, 2024 by sanchit-gandhi

    run sweep

Parent: b1fb7844
Showing 2 changed files with 43 additions and 15 deletions:

    audio_classification_scripts/run_sweep.yaml   +17 -10
    run_audio_classification.py                   +26 -5
audio_classification_scripts/run_sweep.yaml (view file @ 5b5167d8)
@@ -7,10 +7,11 @@ command:
   - --do_eval
   - --trust_remote_code
   - --overwrite_output_dir
+  - --ignore_mismatched_sizes
   - ${args}
 method: grid
 metric:
-  goal: minimize
+  goal: maximize
   name: eval/accuracy
 parameters:
   model_name_or_path:
@@ -40,14 +41,18 @@ parameters:
     value: false
   learning_rate:
     value: 1e-4
   lr_scheduler_type:
     value: constant_with_warmup
   max_length_seconds:
     value: 20
+  min_length_seconds:
+    value: 5
+  attention_mask:
+    value: false
-  warmup_ratio:
-    value: 0.1
-  num_train_epochs:
-    value: 5
+  warmup_steps:
+    value: 50
+  max_steps:
+    value: 1000
   per_device_train_batch_size:
     value: 32
   per_device_eval_batch_size:
@@ -55,19 +60,21 @@ parameters:
   preprocessing_num_workers:
     value: 16
   dataloader_num_workers:
-    value: 4
+    value: 8
   logging_strategy:
     value: steps
   logging_steps:
     value: 10
   evaluation_strategy:
-    value: epoch
+    value: steps
+  eval_steps:
+    value: 1000
   save_strategy:
-    value: epoch
+    value: steps
+  save_steps:
+    value: 1000
   metric_for_best_model:
     value: accuracy
   save_total_limit:
     value: 3
   freeze_base_model:
     values:
       - true
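For context, run_sweep.yaml is a Weights & Biases-style sweep configuration: each agent executes the `command` block, `${args}` is replaced with one combination drawn from `parameters`, and `method: grid` searches for the highest `eval/accuracy`. The sketch below is not part of the commit; it only illustrates how the `value`/`values` entries expand into per-run argument lists (in practice the sweep is registered and executed with the usual `wandb sweep` / `wandb agent` workflow).

    # Illustrative only: expand the sweep's parameter grid into per-run CLI args.
    import itertools
    import yaml

    with open("audio_classification_scripts/run_sweep.yaml") as f:
        config = yaml.safe_load(f)

    fixed, swept = {}, {}
    for name, spec in config["parameters"].items():
        if "values" in spec:      # swept parameter, e.g. freeze_base_model
            swept[name] = spec["values"]
        elif "value" in spec:     # fixed parameter, e.g. learning_rate
            fixed[name] = spec["value"]

    for combo in itertools.product(*swept.values()):
        run_args = {**fixed, **dict(zip(swept.keys(), combo))}
        print(" ".join(f"--{k}={v}" for k, v in run_args.items()))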
run_audio_classification.py (view file @ 5b5167d8)
@@ -35,7 +35,7 @@ from transformers import (
     HfArgumentParser,
     Trainer,
     TrainingArguments,
-    set_seed,
+    WhisperForAudioClassification,
+    set_seed,
 )
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from transformers.trainer_utils import get_last_checkpoint
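The newly imported WhisperForAudioClassification is the transformers class that places a classification head on the Whisper encoder. A minimal usage sketch, with a placeholder checkpoint and label count not taken from this commit:

    from transformers import AutoFeatureExtractor, WhisperForAudioClassification

    feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-small")
    model = WhisperForAudioClassification.from_pretrained(
        "openai/whisper-small",  # placeholder checkpoint
        num_labels=7,            # placeholder label count; the head is newly initialised
    )
    print(model.config.num_labels)  # 7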
@@ -165,7 +165,11 @@ class DataTrainingArguments:
     )
     max_length_seconds: float = field(
         default=20,
-        metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
+        metadata={"help": "Audio samples will be randomly cut to this length during training if the value is set."},
     )
+    min_length_seconds: float = field(
+        default=5,
+        metadata={"help": "Audio samples less than this value will be filtered during training if the value is set."},
+    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
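These dataclass fields become CLI flags through HfArgumentParser, which is how the sweep sets them via `${args}`. A small sketch of that mechanism; the trimmed-down dataclass here is illustrative, not the script's full definition:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class DataTrainingArguments:
        max_length_seconds: float = field(default=20, metadata={"help": "Clips are cut to this length."})
        min_length_seconds: float = field(default=5, metadata={"help": "Shorter clips are filtered out."})

    parser = HfArgumentParser(DataTrainingArguments)
    (data_args,) = parser.parse_args_into_dataclasses(["--min_length_seconds", "2.5"])
    print(data_args.min_length_seconds)  # 2.5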
@@ -197,7 +201,10 @@ class ModelArguments:
         default=None, metadata={"help": "Name or path of preprocessor config."}
     )
     freeze_feature_encoder: bool = field(
-        default=False, metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."}
+        default=False,
+        metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."},
     )
     freeze_base_model: bool = field(
         default=True, metadata={"help": "Whether to freeze the base encoder of the model."}
@@ -225,7 +232,7 @@ class ModelArguments:
         },
     )
     ignore_mismatched_sizes: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
     )
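Flipping ignore_mismatched_sizes to True matters because the sweep fine-tunes checkpoints whose classification heads do not match the target label set; per the help text, the flag lets such a model load with its head re-initialised instead of raising. A hedged sketch of the effect, with a placeholder checkpoint and label count:

    from transformers import AutoModelForAudioClassification

    model = AutoModelForAudioClassification.from_pretrained(
        "facebook/wav2vec2-base",      # placeholder audio checkpoint
        num_labels=7,                  # head size for the target label set
        ignore_mismatched_sizes=True,  # re-initialise the mismatched head rather than error
    )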
@@ -535,6 +542,19 @@ def main():
         desc="Filtering by labels",
     )

+    # filter training data with inputs < min_input_length
+    min_input_length = data_args.min_length_seconds * sampling_rate
+
+    def is_audio_valid(audio):
+        return len(audio["array"]) > min_input_length
+
+    raw_datasets = raw_datasets.filter(
+        is_audio_valid,
+        input_columns=["audio"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Filtering by audio length",
+    )
+
     # Prepare label mappings
     raw_datasets = raw_datasets.map(
         lambda label: {"labels": preprocess_labels(label)},
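The new filter drops training clips shorter than min_length_seconds by comparing raw sample counts. A self-contained sketch of the same logic on a toy dataset; the 16 kHz rate and toy clip lengths are assumptions, not values from the commit:

    import numpy as np
    from datasets import Dataset

    sampling_rate = 16000
    min_input_length = 5 * sampling_rate  # data_args.min_length_seconds * sampling_rate

    ds = Dataset.from_dict({
        "audio": [
            {"array": np.zeros(3 * sampling_rate)},  # 3 s clip -> filtered out
            {"array": np.zeros(8 * sampling_rate)},  # 8 s clip -> kept
        ],
    })

    ds = ds.filter(
        lambda audio: len(audio["array"]) > min_input_length,
        input_columns=["audio"],
        desc="Filtering by audio length",
    )
    print(len(ds))  # 1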
@@ -631,7 +651,8 @@ def main():
     if hasattr(model, "freeze_base_model"):
         # wav2vec2-style models
         model.freeze_base_model()
-        model.freeze_feature_encoder()
+        if hasattr(model, "freeze_feature_encoder"):
+            model.freeze_feature_encoder()
     elif hasattr(model, "freeze_encoder"):
         # whisper-style models
         model.freeze_encoder()
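The guarded freeze_feature_encoder call avoids an AttributeError on architectures that expose freeze_base_model but no feature encoder. A sketch of the dispatch with stand-in classes; the real models come from transformers, these stubs only make the control flow explicit:

    class Wav2Vec2Like:
        def freeze_base_model(self): print("froze wav2vec2-style base model")
        def freeze_feature_encoder(self): print("froze feature encoder")

    class WhisperLike:
        def freeze_encoder(self): print("froze whisper-style encoder")

    def freeze_for_classification(model):
        if hasattr(model, "freeze_base_model"):
            # wav2vec2-style models
            model.freeze_base_model()
            if hasattr(model, "freeze_feature_encoder"):
                model.freeze_feature_encoder()
        elif hasattr(model, "freeze_encoder"):
            # whisper-style models
            model.freeze_encoder()

    freeze_for_classification(Wav2Vec2Like())
    freeze_for_classification(WhisperLike())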