chenpangpang / parler-tts · Commits

Commit 7cbf4d55 ("debugging")
Authored Feb 14, 2024 by sanchit-gandhi
Parent: 9b7b518e

Showing 3 changed files with 143 additions and 76 deletions:
- audio_classification_scripts/run_mms_lid.sh (+12, -7)
- audio_classification_scripts/run_wav2vec2_dummy.sh (+38, -0)
- run_audio_classification.py (+93, -69)
audio_classification_scripts/run_mms_lid.sh
@@ -2,10 +2,14 @@
 python run_audio_classification.py \
     --model_name_or_path "facebook/mms-lid-126" \
-    --train_dataset_name "vctk+facebook/voxpopuli" \
-    --train_dataset_config_name "default+en_accented" \
-    --train_split_name "train+test" \
-    --eval_dataset_name "" \
+    --train_dataset_name "vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
+    --train_dataset_config_name "default+en_accented+default" \
+    --train_split_name "train+test+validation" \
+    --train_label_column_name "accent" \
+    --eval_dataset_name "sanchit-gandhi/edacc" \
+    --eval_dataset_config_name "default" \
+    --eval_split_name "test" \
+    --eval_label_column_name "accent" \
     --output_dir "./" \
     --do_train \
     --do_eval \
@@ -13,12 +17,12 @@ python run_audio_classification.py \
     --remove_unused_columns False \
     --fp16 \
     --learning_rate 1e-4 \
-    --max_length_seconds 10 \
+    --min_length_seconds 5 \
+    --max_length_seconds 20 \
     --attention_mask False \
     --warmup_ratio 0.1 \
     --num_train_epochs 5 \
     --per_device_train_batch_size 32 \
     --gradient_accumulation_steps 4 \
     --per_device_eval_batch_size 32 \
     --dataloader_num_workers 4 \
     --logging_strategy "steps" \
@@ -29,4 +33,5 @@ python run_audio_classification.py \
     --metric_for_best_model "accuracy" \
     --save_total_limit 3 \
     --seed 0 \
-    --push_to_hub
+    --push_to_hub \
+    --trust_remote_code
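The run_mms_lid.sh change concatenates several Hub datasets into single "+"-separated strings for --train_dataset_name, --train_dataset_config_name and --train_split_name. As a rough illustration of that convention only (not the script's actual load_multiple_datasets implementation, which is only partially visible in this diff), the sketch below splits such strings into one spec dict per dataset; the helper name, the broadcast of a single label column, and the dict keys are assumptions.

import re
from typing import Dict, List

def split_dataset_specs(names: str, configs: str, splits: str, label_columns: str) -> List[Dict[str, str]]:
    """Hypothetical helper: turn '+'-separated CLI strings into one dict per dataset."""
    name_list = names.split("+")
    config_list = configs.split("+")
    split_list = splits.split("+")
    # a single label column is assumed to apply to every dataset
    label_list = label_columns.split("+") if "+" in label_columns else [label_columns] * len(name_list)
    if not (len(name_list) == len(config_list) == len(split_list) == len(label_list)):
        raise ValueError("Expected the same number of '+'-separated entries for each argument.")
    return [
        {"name": n, "config": c, "split": s, "label_column_name": l}
        for n, c, s, l in zip(name_list, config_list, split_list, label_list)
    ]

# Example with the values from run_mms_lid.sh above:
specs = split_dataset_specs(
    "vctk+facebook/voxpopuli+sanchit-gandhi/edacc",
    "default+en_accented+default",
    "train+test+validation",
    "accent",
)
# -> three dicts, one per training dataset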
audio_classification_scripts/run_wav2vec2_dummy.sh (new file, mode 100644)
#!/usr/bin/env bash
python run_audio_classification.py \
    --model_name_or_path "hf-internal-testing/tiny-random-wav2vec2" \
    --train_dataset_name "facebook/voxpopuli" \
    --train_dataset_config_name "en_accented" \
    --train_split_name "test" \
    --train_label_column_name "accent" \
    --eval_dataset_name "facebook/voxpopuli" \
    --eval_dataset_config_name "en_accented" \
    --eval_split_name "test" \
    --eval_label_column_name "accent" \
    --trust_remote_code \
    --output_dir "./" \
    --do_train \
    --do_eval \
    --max_train_samples 100 \
    --max_eval_samples 100 \
    --overwrite_output_dir \
    --remove_unused_columns False \
    --fp16 \
    --learning_rate 1e-4 \
    --min_length_seconds 5 \
    --max_length_seconds 10 \
    --attention_mask False \
    --warmup_ratio 0.1 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --dataloader_num_workers 0 \
    --logging_strategy "steps" \
    --logging_steps 10 \
    --evaluation_strategy "epoch" \
    --save_strategy "epoch" \
    --load_best_model_at_end True \
    --metric_for_best_model "accuracy" \
    --save_total_limit 3 \
    --seed 0
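The dummy script points at the hf-internal-testing/tiny-random-wav2vec2 checkpoint and caps both splits at 100 samples, so it works as a quick end-to-end smoke test. A minimal standalone check that the same tiny checkpoint loads with an audio-classification head is sketched below; it is not part of the commit, and it assumes transformers, torch and numpy are installed and the Hub is reachable.

import numpy as np
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

checkpoint = "hf-internal-testing/tiny-random-wav2vec2"
# A sequence-classification head is newly initialised on top of the tiny base model,
# which is also what the training script does before fine-tuning.
model = AutoModelForAudioClassification.from_pretrained(checkpoint, num_labels=2)

# Default wav2vec2 feature extractor (16 kHz); one second of silence exercises a forward pass.
feature_extractor = Wav2Vec2FeatureExtractor()
dummy_audio = np.zeros(feature_extractor.sampling_rate, dtype=np.float32)
inputs = feature_extractor(dummy_audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)  # torch.Size([1, 2])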
run_audio_classification.py
@@ -16,6 +16,7 @@
 import logging
 import os
+import re
 import sys
 from dataclasses import dataclass, field
 from random import randint
@@ -36,9 +37,9 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
+from transformers.models.whisper.tokenization_whisper import LANGUAGES


 logger = logging.getLogger(__name__)
@@ -56,19 +57,20 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 1600
     return wav[random_offset : random_offset + sample_length]


-def preprocess_labels(labels: List[str]) -> List[str]:
+def preprocess_labels(label: str) -> str:
     """Apply pre-processing formatting to the accent labels"""
-    processed_labels = []
-    for label in labels:
-        if "_" in label:
-            # voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
-            language_code = label.split("_")[-1]
-            label = LANGUAGES[language_code]
-        if label == "British":
-            # 1 speaker in VCTK is labelled as British instead of English - let's normalise
-            label = "English"
-        processed_labels.append(label.capitalize())
-    return processed_labels
+    if "_" in label:
+        # voxpopuli stylises the accent as a language code (e.g. en_pl for "polish") - convert to full accent
+        language_code = label.split("_")[-1]
+        label = LANGUAGES[language_code]
+    if label == "British":
+        # 1 speaker in VCTK is labelled as British instead of English - let's normalise
+        label = "English"
+    # VCTK labels for two words are concatenated into one (NewZealand -> New Zealand)
+    label = re.sub(r"(\w)([A-Z])", r"\1 \2", label)
+    # convert Whisper language code (polish) to capitalised (Polish)
+    label = label.capitalize()
+    return label


 @dataclass
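This refactor turns preprocess_labels from a batched List[str] -> List[str] function into a per-example str -> str function, adds a regex that splits run-together VCTK labels, and keeps the Whisper LANGUAGES lookup for VoxPopuli codes. A minimal standalone sketch of the same per-example logic, using a stub LANGUAGES dict in place of the transformers import so it runs without the library:

import re

# Stub standing in for transformers.models.whisper.tokenization_whisper.LANGUAGES,
# which maps ISO codes to lower-case language names (e.g. "pl" -> "polish").
LANGUAGES = {"pl": "polish", "en": "english", "it": "italian"}

def preprocess_label(label: str) -> str:
    """Mirror of the commit's per-example label normalisation."""
    if "_" in label:
        # VoxPopuli styles the accent as a language code, e.g. "en_pl" -> "polish"
        label = LANGUAGES[label.split("_")[-1]]
    if label == "British":
        # one VCTK speaker is labelled "British" instead of "English"
        label = "English"
    # split run-together VCTK labels, e.g. "NewZealand" -> "New Zealand"
    label = re.sub(r"(\w)([A-Z])", r"\1 \2", label)
    return label.capitalize()

print(preprocess_label("en_pl"))       # Polish
print(preprocess_label("British"))     # English
print(preprocess_label("NewZealand"))  # New zealand (str.capitalize lower-cases later words)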
@@ -161,10 +163,18 @@ class DataTrainingArguments:
             )
         },
     )
+    min_length_seconds: float = field(
+        default=5,
+        metadata={"help": "Audio clips less than this value will be filtered during training."},
+    )
     max_length_seconds: float = field(
         default=20,
         metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
     )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )


 @dataclass
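These dataclass fields are what the new --min_length_seconds and --max_length_seconds flags in the shell scripts map onto. For reference, a small sketch of how such fields are typically parsed in transformers example scripts (this assumes the usual HfArgumentParser pattern; the surrounding ModelArguments class and the rest of DataTrainingArguments are not shown in this diff):

from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class DataTrainingArguments:
    # trimmed to the two fields touched by this commit
    min_length_seconds: float = field(
        default=5,
        metadata={"help": "Audio clips less than this value will be filtered during training."},
    )
    max_length_seconds: float = field(
        default=20,
        metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
    )

parser = HfArgumentParser((DataTrainingArguments, TrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses(
    ["--output_dir", "./", "--min_length_seconds", "5", "--max_length_seconds", "20"]
)
print(data_args.min_length_seconds, data_args.max_length_seconds)  # 5.0 20.0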
@@ -326,7 +336,7 @@ def load_multiple_datasets(
         if dataset_dict["label_column_name"] not in dataset_features:
             raise ValueError(
-                f"Label column name {dataset_dict['text_column_name']} not found in dataset"
+                f"Label column name {dataset_dict['label_column_name']} not found in dataset"
                 f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the"
                 f" correct text column - one of {', '.join(dataset_features)}."
             )
@@ -423,12 +433,12 @@ def main():
             data_args.train_dataset_config_name,
             splits=data_args.train_split_name,
             label_column_names=data_args.train_label_column_name,
             streaming=data_args.streaming,
             dataset_samples=data_args.train_dataset_samples,
             seed=training_args.seed,
-            cache_dir=data_args.dataset_cache_dir,
+            cache_dir=model_args.cache_dir,
             token=True if model_args.token else None,
-            trust_remote_code=data_args.trust_remote_code,
+            trust_remote_code=model_args.trust_remote_code,
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )

     if training_args.do_eval:
@@ -449,10 +459,10 @@ def main():
                 dataset_dict["name"],
                 dataset_dict["config"],
                 split=dataset_dict["split"],
-                cache_dir=data_args.dataset_cache_dir,
+                cache_dir=model_args.cache_dir,
                 token=True if model_args.token else None,
-                streaming=data_args.streaming,
-                trust_remote_code=data_args.trust_remote_code,
+                trust_remote_code=model_args.trust_remote_code,
+                # streaming=data_args.streaming,
             )
         else:
             # load multiple eval sets
@@ -463,10 +473,10 @@ def main():
                 dataset_dict["name"],
                 dataset_dict["config"],
                 split=dataset_dict["split"],
-                cache_dir=data_args.dataset_cache_dir,
+                cache_dir=model_args.cache_dir,
                 token=True if model_args.use_auth_token else None,
-                streaming=data_args.streaming,
-                trust_remote_code=data_args.trust_remote_code,
+                trust_remote_code=model_args.trust_remote_code,
+                # streaming=data_args.streaming,
             )
             features = raw_datasets[pretty_name].features.keys()
             if dataset_dict["label_column_name"] not in features:
@@ -505,36 +515,70 @@ def main():
         data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
     )

-    model_input_name = feature_extractor.model_input_names[0]
-
-    def train_transforms(batch):
-        """Apply train_transforms across a batch."""
-        subsampled_wavs = []
-        for audio in batch[data_args.audio_column_name]:
-            wav = random_subsample(
-                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
-            )
-            subsampled_wavs.append(wav)
-        inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
-        output_batch = {model_input_name: inputs.get(model_input_name)}
-        output_batch["labels"] = preprocess_labels(batch["labels"])
-        return output_batch
-
-    def val_transforms(batch):
-        """Apply val_transforms across a batch."""
-        wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
-        inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
-        output_batch = {model_input_name: inputs.get(model_input_name)}
-        output_batch["labels"] = preprocess_labels(batch["labels"])
-        return output_batch
+    if training_args.do_train:
+        if data_args.max_train_samples is not None:
+            raw_datasets["train"] = (
+                raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
+            )
+
+    if training_args.do_eval:
+        if data_args.max_eval_samples is not None:
+            raw_datasets["eval"] = (
+                raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
+            )
+
+    sampling_rate = feature_extractor.sampling_rate
+    model_input_name = feature_extractor.model_input_names[0]
+    max_input_length = data_args.max_length_seconds * sampling_rate
+    min_input_length = data_args.min_length_seconds * sampling_rate
+
+    def prepare_dataset(sample):
+        audio = sample["audio"]["array"]
+        if len(audio) / sampling_rate > max_input_length:
+            audio = random_subsample(audio, max_input_length, sampling_rate)
+        inputs = feature_extractor(audio, sampling_rate=sampling_rate)
+        sample[model_input_name] = inputs.get(model_input_name)
+        sample["input_length"] = len(audio) / sampling_rate
+        sample["labels"] = preprocess_labels(sample["labels"])
+        return sample
+
+    vectorized_datasets = raw_datasets.map(
+        prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
+    )
+
+    # filter training data with inputs longer than max_input_length
+    def is_audio_in_length_range(length):
+        return min_input_length < length < max_input_length
+
+    vectorized_datasets = vectorized_datasets.filter(
+        is_audio_in_length_range,
+        input_columns=["input_length"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Filtering by audio length",
+    )
+
+    # filter training data with non valid labels
+    def is_label_valid(label):
+        return label != "Unknown"
+
+    vectorized_datasets = vectorized_datasets.filter(
+        is_label_valid,
+        input_columns=["labels"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Filtering by labels",
+    )

     # Prepare label mappings.
     # We'll include these in the model's config to get human readable labels in the Inference API.
-    labels = raw_datasets["train"]["label"]
-    label2id, id2label = {}, {}
+    labels = vectorized_datasets["train"]["labels"]
+    label2id, id2label, num_label = {}, {}, {}
     for i, label in enumerate(labels):
-        label2id[label] = str(i)
-        id2label[str(i)] = label
+        num_label[label] += 1
+        if label not in label2id:
+            label2id[label] = str(i)
+            id2label[str(i)] = label
+
+    logger.info(f"Number of labels: {num_label}")

     # Load the accuracy metric from the datasets package
     metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
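Note that in this hunk the new prepare_dataset stores input_length in seconds while min_input_length and max_input_length are computed in samples (seconds times sampling_rate), and num_label[label] += 1 is applied to a plain dict, so the hunk still looks mid-debug, consistent with the commit message. A small self-contained sketch of the same map/filter pattern with consistent units and a Counter, using a toy in-memory dataset instead of the script's Hub datasets (assumes the datasets and numpy packages are installed):

from collections import Counter
import numpy as np
from datasets import Dataset

SAMPLING_RATE = 16_000
MIN_LENGTH_SECONDS, MAX_LENGTH_SECONDS = 5.0, 20.0

# Toy stand-in for the loaded audio dataset: three clips of 2 s, 8 s and 25 s.
toy = Dataset.from_dict(
    {
        "audio": [np.zeros(int(s * SAMPLING_RATE)).tolist() for s in (2, 8, 25)],
        "labels": ["Polish", "English", "Unknown"],
    }
)

def prepare(sample):
    audio = np.asarray(sample["audio"])
    # keep every length in seconds so the later comparison is unit-consistent
    sample["input_length"] = len(audio) / SAMPLING_RATE
    return sample

def in_length_range(length_seconds):
    return MIN_LENGTH_SECONDS < length_seconds < MAX_LENGTH_SECONDS

def label_is_valid(label):
    return label != "Unknown"

prepared = toy.map(prepare)
filtered = prepared.filter(in_length_range, input_columns=["input_length"])
filtered = filtered.filter(label_is_valid, input_columns=["labels"])

print(filtered["labels"])           # ['English']
print(Counter(prepared["labels"]))  # per-label counts without a KeyError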
@@ -572,32 +616,12 @@ def main():
     if model_args.freeze_feature_encoder:
         model.freeze_feature_encoder()

-    if training_args.do_train:
-        if data_args.max_train_samples is not None:
-            raw_datasets["train"] = (
-                raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
-            )
-        # Set the training transforms
-        raw_datasets["train"].set_transform(train_transforms, columns=[model_input_name, "labels"], output_all_columns=False)
-
-    if training_args.do_eval:
-        if data_args.max_eval_samples is not None:
-            raw_datasets["eval"] = (
-                raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
-            )
-        # Set the validation transforms
-        raw_datasets["eval"].set_transform(val_transforms, columns=[model_input_name, "labels"], output_all_columns=False)
-
     # Initialize our trainer
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=raw_datasets["train"] if training_args.do_train else None,
-        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
+        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
+        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
     )
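The compute_metrics callable passed to the Trainer above is not shown in this diff. In the stock transformers audio-classification example it is an accuracy metric over argmaxed logits; a self-contained sketch of that pattern (an assumption about this repo, not part of the commit):

import numpy as np
import evaluate
from transformers import EvalPrediction

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred: EvalPrediction):
    """Accuracy from a (logits, label_ids) EvalPrediction pair."""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

# quick check with fake logits for three examples and two classes
fake = EvalPrediction(
    predictions=np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]),
    label_ids=np.array([1, 0, 0]),
)
print(compute_metrics(fake))  # {'accuracy': 0.666...}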