chenpangpang / parler-tts

Commit b7b225a4, authored Feb 20, 2024 by sanchit-gandhi
Commit message: dataset concat
Parent: 8820042c

Showing 2 changed files with 372 additions and 0 deletions:
- dataset_concatenation_scripts/run_dataset_concatenation.sh (+11, -0)
- run_dataset_concatenation.py (+361, -0)
dataset_concatenation_scripts/run_dataset_concatenation.sh (new file, mode 100644)
#!/usr/bin/env bash

python run_dataset_concatenation.py \
    --dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
    --dataset_config_name "default+en_accented+default" \
    --dataset_split_name "train+test+validation" \
    --label_column_name "accent+accent+accent" \
    --text_column_name "text+normalized_text+text" \
    --speaker_column_name "speaker_id+speaker_id+speaker" \
    --batch_size 250 \
    --output_dir "./concatenated-dataset"
run_dataset_concatenation.py (new file, mode 100644)
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
from datasets import Audio, concatenate_datasets, load_dataset
from huggingface_hub import get_full_repo_name
from transformers import HfArgumentParser, WhisperTokenizerFast


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: str = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    dataset_split_name: str = field(
        default=None,
        metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"},
    )
    label_column_name: str = field(
        default="labels",
        metadata={"help": "The name of the dataset column containing the labels in the dataset. Defaults to 'labels'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the dataset column containing the text transcriptions in the dataset. Defaults to 'text'"},
    )
    speaker_column_name: str = field(
        default="speaker_id",
        metadata={"help": "The name of the dataset column containing the speaker ids in the dataset. Defaults to 'speaker_id'"},
    )
    dataset_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    preprocessing_num_workers: int = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    batch_size: int = field(
        default=500,
        metadata={"help": "Number of examples per batch provided to the preprocessing function."},
    )
    download_only: bool = field(
        default=False,
        metadata={"help": "Whether to only do data download and skip pre-processing."},
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
    )
    sampling_rate: int = field(
        default=16_000,
        metadata={"help": "Sampling rate at which to resample the audio data. Should be set to the same sampling rate as the target model."},
    )
    max_samples: int = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples in the dataset to this value if set."},
    )
    output_dir: str = field(
        default=None,
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    seed: int = field(
        default=0,
        metadata={"help": "RNG seed for reproducibility. Used during the final shuffling of the combined dataset."},
    )

def convert_dataset_str_to_list(
    dataset_names,
    dataset_config_names,
    splits=None,
    label_column_names=None,
    text_column_names=None,
    speaker_column_names=None,
    dataset_samples=None,
    default_split="train",
):
    if isinstance(dataset_names, str):
        dataset_names = dataset_names.split("+")
        dataset_config_names = dataset_config_names.split("+")
        splits = splits.split("+") if splits is not None else None
        label_column_names = label_column_names.split("+") if label_column_names is not None else None
        text_column_names = text_column_names.split("+") if text_column_names is not None else None
        speaker_column_names = speaker_column_names.split("+") if speaker_column_names is not None else None
        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
    # basic checks to ensure we've got the right number of datasets/configs/splits/columns
    if len(dataset_names) != len(dataset_config_names):
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if label_column_names is not None and len(label_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one label column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(label_column_names)} label column names."
        )

    if text_column_names is not None and len(text_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(text_column_names)} text column names."
        )

    if speaker_column_names is not None and len(speaker_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one speaker column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(speaker_column_names)} speaker column names."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != len(dataset_names):
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * len(dataset_names)
    label_column_names = (
        label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
    )
    text_column_names = (
        text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
    )
    speaker_column_names = (
        speaker_column_names if speaker_column_names is not None else ["speaker_id" for _ in range(len(dataset_names))]
    )
    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]

    dataset_names_dict = []
    for i, ds_name in enumerate(dataset_names):
        dataset_names_dict.append(
            {
                "name": ds_name,
                "config": dataset_config_names[i],
                "split": splits[i],
                "label_column_name": label_column_names[i],
                "text_column_name": text_column_names[i],
                "speaker_column_name": speaker_column_names[i],
                "samples": dataset_samples[i],
            }
        )
    return dataset_names_dict
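
# For example, with the "+"-separated arguments used in run_dataset_concatenation.sh above,
# the first entry returned by convert_dataset_str_to_list would be:
#     {"name": "sanchit-gandhi/vctk", "config": "default", "split": "train",
#      "label_column_name": "accent", "text_column_name": "text",
#      "speaker_column_name": "speaker_id", "samples": None}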

def main():
    # 1. Parse input arguments
    parser = HfArgumentParser(DataTrainingArguments)

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        data_args = parser.parse_args_into_dataclasses()[0]
    dataset_names_dict = convert_dataset_str_to_list(
        data_args.dataset_name,
        data_args.dataset_config_name,
        splits=data_args.dataset_split_name,
        label_column_names=data_args.label_column_name,
        text_column_names=data_args.text_column_name,
        speaker_column_names=data_args.speaker_column_name,
    )

    # load whisper tokenizer for normalisation
    sampling_rate = data_args.sampling_rate
    tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny.en")

    max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
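    # e.g. with the defaults (20.0 seconds at a 16,000 Hz sampling rate), max_input_length is 320,000 samples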
    batch_size = data_args.batch_size
    preprocessing_num_workers = data_args.preprocessing_num_workers

    all_vectorized_datasets = []
    for dataset_dict in dataset_names_dict:
        print(10 * "=", dataset_dict["name"], 10 * "=")
        raw_datasets = load_dataset(
            dataset_dict["name"],
            dataset_dict["config"],
            split=dataset_dict["split"],
            cache_dir=data_args.dataset_cache_dir,
            num_proc=data_args.preprocessing_num_workers,
        )

        if data_args.download_only:
            continue

        features = raw_datasets.column_names
        if dataset_dict["label_column_name"] not in features:
            raise ValueError(
                f"--label_column_name {dataset_dict['label_column_name']} not found in dataset '{dataset_dict['name']}'. "
                "Make sure to set `--label_column_name` to the correct label column - one of "
                f"{', '.join(features)}."
            )
        elif dataset_dict["label_column_name"] != "labels":
            raw_datasets = raw_datasets.rename_column(dataset_dict["label_column_name"], "labels")

        if dataset_dict["text_column_name"] not in features:
            raise ValueError(
                f"--text_column_name {dataset_dict['text_column_name']} not found in dataset '{dataset_dict['name']}'. "
                "Make sure to set `--text_column_name` to the correct text column - one of "
                f"{', '.join(features)}."
            )
        elif dataset_dict["text_column_name"] != "text":
            raw_datasets = raw_datasets.rename_column(dataset_dict["text_column_name"], "text")

        if dataset_dict["speaker_column_name"] not in features:
            raise ValueError(
                f"--speaker_column_name {dataset_dict['speaker_column_name']} not found in dataset '{dataset_dict['name']}'. "
                "Make sure to set `--speaker_column_name` to the correct speaker id column - one of "
                f"{', '.join(features)}."
            )
        elif dataset_dict["speaker_column_name"] != "speaker_id":
            raw_datasets = raw_datasets.rename_column(dataset_dict["speaker_column_name"], "speaker_id")
        raw_datasets = raw_datasets.remove_columns(
            set(raw_datasets.features.keys()) - {"audio", "labels", "text", "speaker_id"}
        )

        if data_args.max_samples is not None:
            raw_datasets = raw_datasets.select(range(data_args.max_samples))

        raw_datasets = raw_datasets.cast_column(data_args.audio_column_name, Audio(sampling_rate=sampling_rate))
        raw_datasets = raw_datasets.sort("speaker_id")

        def filter_transcriptions(text):
            normalized_text = tokenizer.normalize(text).strip()
            return bool(normalized_text) and text.lower() != "ignore_time_segment_in_scoring"

        raw_datasets = raw_datasets.filter(
            filter_transcriptions, input_columns=["text"], desc="Filtering non-speech transcriptions"
        )
        def prepare_dataset(batch):
            audio = [sample["array"] for sample in batch["audio"]]
            input_lengths = [len(sample) for sample in audio]

            concatenated_audio = []
            concatenated_text = []
            concatenated_speaker = []
            concatenated_labels = []

            audio_sample = audio[0]
            text_sample = batch["text"][0]
            label_sample = batch["labels"][0]

            for idx in range(1, len(audio)):
                prev_speaker = batch["speaker_id"][idx - 1]
                speaker = batch["speaker_id"][idx]

                if len(audio_sample) + input_lengths[idx] < max_input_length:
                    if speaker == prev_speaker:
                        # we have no information about whether the segments follow on sequentially
                        # so we just ensure the same speaker as we concatenate across files
                        audio_sample = np.append(audio_sample, audio[idx])
                        # extra spaces in the text transcription don't matter, since we only use it for the WER computation
                        text_sample += " " + batch["text"][idx]
                    else:
                        # segments do not follow sequentially, save the audio and start looping again
                        concatenated_audio.append(audio_sample)
                        concatenated_text.append(text_sample)
                        concatenated_labels.append(label_sample)
                        concatenated_speaker.append(speaker)
                        audio_sample = audio[idx]
                        text_sample = batch["text"][idx]
                        label_sample = batch["labels"][idx]
                else:
                    # concatenated audio exceeds max length, save the audio and start looping again
                    concatenated_audio.append(audio_sample)
                    concatenated_text.append(text_sample)
                    concatenated_labels.append(label_sample)
                    concatenated_speaker.append(speaker)
                    audio_sample = audio[idx]
                    text_sample = batch["text"][idx]
                    label_sample = batch["labels"][idx]

            batch["audio"] = [{"array": array, "sampling_rate": sampling_rate} for array in concatenated_audio]
            batch["text"] = concatenated_text
            batch["labels"] = concatenated_labels
            batch["speaker_id"] = concatenated_speaker
            return batch
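
        # In short: prepare_dataset greedily appends consecutive clips from the same speaker until adding
        # the next clip would reach or exceed max_input_length (or the speaker changes), then flushes the
        # accumulated example and starts a new one from the current clip.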
        raw_datasets = raw_datasets.map(
            prepare_dataset,
            batched=True,
            batch_size=batch_size,
            num_proc=preprocessing_num_workers,
            desc="Concatenating dataset...",
        )
        pretty_name = dataset_dict["name"].split("/")[-1]

        def postprocess_ids(speaker_id, idx):
            formatted_idx = f"{pretty_name}-{speaker_id}-{idx}"
            return {"id": formatted_idx}

        raw_datasets = raw_datasets.map(
            postprocess_ids,
            input_columns=["speaker_id"],
            with_indices=True,
            desc="Setting sample idxs...",
            num_proc=preprocessing_num_workers,
        )

        print(f"Final length {pretty_name}: ", len(raw_datasets))
        # set the dataset format to numpy arrays before concatenating across datasets
        raw_datasets = raw_datasets.with_format("np")
        all_vectorized_datasets.append(raw_datasets)
    all_vectorized_datasets = concatenate_datasets(all_vectorized_datasets)

    dataset_features = all_vectorized_datasets.features.copy()
    dataset_features["audio"] = Audio(sampling_rate=sampling_rate)
    all_vectorized_datasets = all_vectorized_datasets.cast(
        dataset_features, batch_size=batch_size, writer_batch_size=batch_size, num_proc=preprocessing_num_workers
    )

    all_vectorized_datasets = all_vectorized_datasets.shuffle(seed=data_args.seed)
    all_vectorized_datasets.save_to_disk(data_args.output_dir)

    repo_name = get_full_repo_name(Path(data_args.output_dir).absolute().name)
    if data_args.push_to_hub:
        all_vectorized_datasets.push_to_hub(repo_name, config_name="train", max_shard_size="1GB")


if __name__ == "__main__":
    main()
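
The script writes the shuffled, combined dataset to `--output_dir` with `save_to_disk`. A minimal sketch of reloading the result for inspection, assuming the `./concatenated-dataset` output directory used in run_dataset_concatenation.sh:

from datasets import load_from_disk

# Reload the concatenated dataset written by run_dataset_concatenation.py.
dataset = load_from_disk("./concatenated-dataset")
print(len(dataset))

# Each example keeps the standardised columns ("audio", "text", "labels", "speaker_id")
# plus the generated "id" of the form "<dataset>-<speaker_id>-<idx>".
example = dataset[0]
print(example["id"], example["speaker_id"], example["labels"])
print(example["text"])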