GitLab: chenpangpang / parler-tts, Commits

Commit 94f40c57
Authored Feb 14, 2024 by sanchit-gandhi
Parent: 7cbf4d55

    working audio class

Showing 1 changed file with 84 additions and 80 deletions:
run_audio_classification.py (+84 / -80)
@@ -136,13 +136,13 @@ class DataTrainingArguments:
         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
     )
     train_label_column_name: str = field(
-        default="label",
+        default="labels",
         metadata={"help": "The name of the dataset column containing the labels in the train set. Defaults to 'label'"},
     )
     eval_label_column_name: str = field(
-        default="label",
+        default="labels",
         metadata={"help": "The name of the dataset column containing the labels in the eval set. Defaults to 'label'"},
     )
     max_train_samples: Optional[int] = field(
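This hunk flips the default label column for the train and eval sets from "label" to "labels" (note the help strings still read "Defaults to 'label'", so they are now slightly stale). As a minimal, illustrative sketch of how these dataclass fields surface as CLI flags, assuming only HfArgumentParser from transformers and a trimmed stand-in dataclass:

# Sketch only, not part of the commit: a trimmed stand-in for the script's
# DataTrainingArguments, showing how `field` defaults become CLI flags.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class MiniDataTrainingArguments:
    train_label_column_name: str = field(
        default="labels",
        metadata={"help": "The name of the dataset column containing the labels in the train set."},
    )

# No flag passed: the new default "labels" applies.
(args,) = HfArgumentParser(MiniDataTrainingArguments).parse_args_into_dataclasses(args=[])
assert args.train_label_column_name == "labels"

# An explicit flag still overrides the default.
(args,) = HfArgumentParser(MiniDataTrainingArguments).parse_args_into_dataclasses(
    args=["--train_label_column_name", "language"]
)
assert args.train_label_column_name == "language"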
@@ -275,7 +275,7 @@ def convert_dataset_str_to_list(
     dataset_samples = [None] * len(dataset_names)
     label_column_names = (
-        label_column_names if label_column_names is not None else ["label" for _ in range(len(dataset_names))]
+        label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
     )
     splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
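The same rename, applied to the per-dataset default list. Incidentally, the comprehension is equivalent to the tighter ["labels"] * len(dataset_names), which is safe here because strings are immutable:

# Illustrative one-liner, not from the commit:
dataset_names = ["common_voice", "voxpopuli", "fleurs"]  # hypothetical values
assert ["labels" for _ in range(len(dataset_names))] == ["labels"] * len(dataset_names)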
@@ -300,7 +300,7 @@ def load_multiple_datasets(
     label_column_names: Optional[List] = None,
     stopping_strategy: Optional[str] = "first_exhausted",
     dataset_samples: Optional[Union[List, np.array]] = None,
-    streaming: Optional[bool] = True,
+    streaming: Optional[bool] = False,
     seed: Optional[int] = None,
     audio_column_name: Optional[str] = "audio",
     **kwargs,
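The streaming default flips to False, consistent with the commented-out streaming=data_args.streaming call sites later in the diff. A minimal sketch of what the flag changes, with an illustrative dataset name (not from the commit):

from datasets import load_dataset

# streaming=True returns a datasets.IterableDataset: samples are fetched lazily,
# nothing is written to the local cache, and len()/random access are unavailable.
streamed = load_dataset("PolyAI/minds14", "en-US", split="train", streaming=True)

# streaming=False (the new default) returns a map-style datasets.Dataset that is
# downloaded and cached up front, which the .map/.filter/.set_transform calls
# later in this script rely on.
materialized = load_dataset("PolyAI/minds14", "en-US", split="train", streaming=False)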
@@ -342,11 +342,11 @@ def load_multiple_datasets(
         )
         # blanket renaming of all label columns to label
-        if dataset_dict["label_column_name"] != "label":
-            dataset = dataset.rename_column(dataset_dict["label_column_name"], "label")
+        if dataset_dict["label_column_name"] != "labels":
+            dataset = dataset.rename_column(dataset_dict["label_column_name"], "labels")
         dataset_features = dataset.features.keys()
-        columns_to_keep = {"audio", "label"}
+        columns_to_keep = {"audio", "labels"}
         dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
         all_datasets.append(dataset)
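The blanket renaming now targets "labels" as well. A toy, self-contained sketch of the rename-then-prune pattern (illustrative data, not from the commit):

from datasets import Dataset

dataset = Dataset.from_dict({"audio": [[0.1, 0.2]], "lang_id": ["en"], "client_id": ["abc"]})
label_column_name = "lang_id"  # hypothetical per-dataset label column

# Normalize whatever the source calls its label column to "labels".
if label_column_name != "labels":
    dataset = dataset.rename_column(label_column_name, "labels")

# Drop everything except the two columns the training loop consumes.
columns_to_keep = {"audio", "labels"}
dataset = dataset.remove_columns(list(set(dataset.features.keys()) - columns_to_keep))
assert set(dataset.features.keys()) == columns_to_keep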
@@ -451,11 +451,15 @@ def main():
         label_column_names=data_args.eval_label_column_name,
     )

     all_eval_splits = []
-    if len(dataset_names_dict) == 1:
-        # load a single eval set
-        dataset_dict = dataset_names_dict[0]
-        all_eval_splits.append("eval")
-        raw_datasets["eval"] = load_dataset(
+    # load multiple eval sets
+    for dataset_dict in dataset_names_dict:
+        pretty_name = (
+            f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
+            if len(dataset_names_dict) > 1
+            else "eval"
+        )
+        all_eval_splits.append(pretty_name)
+        raw_datasets[pretty_name] = load_dataset(
             dataset_dict["name"],
             dataset_dict["config"],
             split=dataset_dict["split"],
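The single-set special case is folded into the loop: with exactly one eval set, pretty_name degenerates to "eval", so behaviour is unchanged. Illustrative values (not from the commit) showing the naming scheme:

dataset_dict = {"name": "mozilla-foundation/common_voice_11_0", "split": "test.en"}  # hypothetical
pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
assert pretty_name == "common_voice_11_0/test-en"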
@@ -464,34 +468,20 @@ def main():
             trust_remote_code=model_args.trust_remote_code,
             # streaming=data_args.streaming,
         )
-    else:
-        # load multiple eval sets
-        for dataset_dict in dataset_names_dict:
-            pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
-            all_eval_splits.append(pretty_name)
-            raw_datasets[pretty_name] = load_dataset(
-                dataset_dict["name"],
-                dataset_dict["config"],
-                split=dataset_dict["split"],
-                cache_dir=model_args.cache_dir,
-                token=True if model_args.use_auth_token else None,
-                trust_remote_code=model_args.trust_remote_code,
-                # streaming=data_args.streaming,
-            )
-            features = raw_datasets[pretty_name].features.keys()
-            if dataset_dict["label_column_name"] not in features:
-                raise ValueError(
-                    f"--label_column_name {data_args.eval_label_column_name} not found in dataset '{data_args.dataset_name}'. "
-                    "Make sure to set `--label_column_name` to the correct text column - one of "
-                    f"{', '.join(raw_datasets['train'].column_names)}."
-                )
-            elif dataset_dict["label_column_name"] != "label":
-                raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(dataset_dict["label_column_name"], "label")
-            raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
-                set(raw_datasets[pretty_name].features.keys()) - {"audio", "label"}
-            )
+        features = raw_datasets[pretty_name].features.keys()
+        if dataset_dict["label_column_name"] not in features:
+            raise ValueError(
+                f"--label_column_name {data_args.eval_label_column_name} not found in dataset '{data_args.dataset_name}'. "
+                "Make sure to set `--label_column_name` to the correct text column - one of "
+                f"{', '.join(raw_datasets['train'].column_names)}."
+            )
+        elif dataset_dict["label_column_name"] != "labels":
+            raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(dataset_dict["label_column_name"], "labels")
+        raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
+            set(raw_datasets[pretty_name].features.keys()) - {"audio", "labels"}
+        )

     if not training_args.do_train and not training_args.do_eval:
         raise ValueError(
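The old if/else is collapsed into the single loop from the previous hunk: the column-existence guard and the rename-and-prune now run for every eval set, targeting "labels". A toy sketch of the guard (illustrative data and flag value, not from the commit):

from datasets import Dataset

eval_set = Dataset.from_dict({"audio": [[0.0]], "lang_id": ["en"]})
label_column_name = "lang_id"

# The guard passes when the configured column exists; a misspelt column name
# would raise the ValueError listing the available columns instead.
if label_column_name not in eval_set.features.keys():
    raise ValueError(
        f"--label_column_name {label_column_name} not found. "
        f"Choose one of {', '.join(eval_set.column_names)}."
    )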
@@ -529,56 +519,70 @@ def main():
     sampling_rate = feature_extractor.sampling_rate
     model_input_name = feature_extractor.model_input_names[0]

-    max_input_length = data_args.max_length_seconds * sampling_rate
-    min_input_length = data_args.min_length_seconds * sampling_rate
-
-    def prepare_dataset(sample):
-        audio = sample["audio"]["array"]
-        if len(audio) / sampling_rate > max_input_length:
-            audio = random_subsample(audio, max_input_length, sampling_rate)
-        inputs = feature_extractor(audio, sampling_rate=sampling_rate)
-        sample[model_input_name] = inputs.get(model_input_name)
-        sample["input_length"] = len(audio) / sampling_rate
-        sample["labels"] = preprocess_labels(sample["labels"])
-        return sample
-
-    vectorized_datasets = raw_datasets.map(
-        prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
-    )
-
-    # filter training data with inputs longer than max_input_length
-    def is_audio_in_length_range(length):
-        return min_input_length < length < max_input_length
-
-    vectorized_datasets = vectorized_datasets.filter(
-        is_audio_in_length_range,
-        input_columns=["input_length"],
-        num_proc=data_args.preprocessing_num_workers,
-        desc="Filtering by audio length",
-    )
-
-    # filter training data with non valid labels
+    # filter training data with non-valid labels
     def is_label_valid(label):
         return label != "Unknown"

-    vectorized_datasets = vectorized_datasets.filter(
+    raw_datasets = raw_datasets.filter(
         is_label_valid,
         input_columns=["labels"],
         num_proc=data_args.preprocessing_num_workers,
         desc="Filtering by labels",
     )

-    # Prepare label mappings.
+    # Prepare label mappings
+    raw_datasets = raw_datasets.map(
+        lambda label: {"labels": preprocess_labels(label)},
+        input_columns=["labels"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Pre-processing labels",
+    )
+
     # We'll include these in the model's config to get human readable labels in the Inference API.
-    labels = vectorized_datasets["train"]["labels"]
-    label2id, id2label, num_label = {}, {}, {}
-    for i, label in enumerate(labels):
-        num_label[label] += 1
-        if label not in label2id:
-            label2id[label] = str(i)
-            id2label[str(i)] = label
+    set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"]))
+    label2id, id2label = {}, {}
+    for i, label in enumerate(set(set_labels)):
+        label2id[label] = str(i)
+        id2label[str(i)] = label
+
+    train_labels = raw_datasets["train"]["labels"]
+    num_labels = {key: 0 for key in set(train_labels)}
+    for label in train_labels:
+        num_labels[label] += 1
+
+    # Print a summary of the labels to the stddout (helps identify low-label classes that could be filtered)
+    num_labels = sorted(num_labels.items(), key=lambda x: (-x[1], x[0]))
+    logger.info(f"{'Language':<15} {'Count':<5}")
+    logger.info("-" * 20)
+    for language, count in num_labels:
+        logger.info(f"{language:<15} {count:<5}")

     def train_transforms(batch):
         """Apply train_transforms across a batch."""
         subsampled_wavs = []
         for audio in batch["audio"]:
             wav = random_subsample(
                 audio["array"], max_length=data_args.max_length_seconds, sample_rate=sampling_rate
             )
             subsampled_wavs.append(wav)
         inputs = feature_extractor(subsampled_wavs, sampling_rate=sampling_rate)
         output_batch = {model_input_name: inputs.get(model_input_name)}
         output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
         return output_batch

     def val_transforms(batch):
         """Apply val_transforms across a batch."""
         wavs = [audio["array"] for audio in batch["audio"]]
         inputs = feature_extractor(wavs, sampling_rate=sampling_rate)
         output_batch = {model_input_name: inputs.get(model_input_name)}
         output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
         return output_batch

-    logger.info(f"Number of labels: {num_label}")
-
     if training_args.do_train:
         # Set the training transforms
         raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
     if training_args.do_eval:
         # Set the validation transforms
         raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)

     # Load the accuracy metric from the datasets package
     metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
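Two substantive fixes land here. First, the eager prepare_dataset/.map pipeline (which produced vectorized_datasets) is dropped in favour of the lazy transforms, with label preprocessing moved into a lightweight .map over just the "labels" column. Second, the old counting loop did num_label[label] += 1 on an empty dict, a guaranteed KeyError on the first label; the new code builds label2id from the union of train and eval labels and pre-seeds the count dict with zeros. A minimal sketch of the lazy-transform mechanism the script now relies on (toy data and a stand-in transform, not from the commit):

from datasets import Dataset

ds = Dataset.from_dict({"audio": [[0.1, 0.2, 0.3]], "labels": ["en"]})
label2id = {"en": "0"}

def train_transforms(batch):
    # Stand-in for the feature-extractor call in the script: the transform is
    # invoked per accessed batch, so augmentation like random_subsample is
    # re-drawn on every epoch and no preprocessed copy hits the dataset cache.
    output_batch = {"input_values": batch["audio"]}
    output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
    return output_batch

ds.set_transform(train_transforms, output_all_columns=False)
print(ds[0])  # {'input_values': [0.1, 0.2, 0.3], 'labels': 0}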
@@ -592,7 +596,7 @@ def main():
     config = AutoConfig.from_pretrained(
         model_args.config_name or model_args.model_name_or_path,
-        num_labels=len(labels),
+        num_labels=len(label2id),
         label2id=label2id,
         id2label=id2label,
         finetuning_task="audio-classification",
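Another fix that follows from the previous hunk: labels was the old per-sample list (and no longer exists), so len(labels) counted training samples rather than classes; len(label2id) is the number of distinct classes. Illustrative values:

labels = ["en", "fr", "en", "de"]             # one entry per training sample (old code)
label2id = {"de": "0", "en": "1", "fr": "2"}  # one entry per distinct class (new code)
assert len(labels) == 4    # sample count - the wrong num_labels
assert len(label2id) == 3  # class count - what AutoConfig needs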
@@ -620,8 +624,8 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
+        train_dataset=raw_datasets["train"] if training_args.do_train else None,
+        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
         tokenizer=feature_extractor,
     )
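Since train_transforms and val_transforms were attached with set_transform, passing raw_datasets to the Trainer is sufficient: preprocessing happens on the fly as batches are fetched, and the intermediate vectorized_datasets copy disappears entirely.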
@@ -649,7 +653,7 @@ def main():
     kwargs = {
         "finetuned_from": model_args.model_name_or_path,
         "tasks": "audio-classification",
-        "dataset": data_args.dataset_name,
+        "dataset": data_args.train_dataset_name.split("+")[0],
         "tags": ["audio-classification"],
     }

     if training_args.push_to_hub:
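With "+"-concatenated training sets, only the first is reported in the model card's dataset field. Illustrative value (not from the commit):

train_dataset_name = "common_voice_11_0+voxpopuli"  # hypothetical
assert train_dataset_name.split("+")[0] == "common_voice_11_0"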