chenpangpang / parler-tts · Commits · 1fe3fc1e

Commit 1fe3fc1e authored Apr 04, 2024 by yoach@huggingface.co

add possibility to use metadata prompt column name instead of data

Parent: cb109592
Showing 1 changed file with 6 additions and 6 deletions

run_stable_speech_training.py (+6, -6) @ 1fe3fc1e
@@ -723,6 +723,7 @@ def load_multiple_datasets(
        metadata_dataset_name = dataset_dict["metadata_dataset_name"]
        if metadata_dataset_name is not None:
            logger.info(f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}')
            metadata_dataset = load_dataset(
                metadata_dataset_name,
                dataset_dict["config"],
                ...
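The hunk above loads a companion metadata dataset next to the main audio dataset. As a rough illustration of how such a column-wise merge can be done with the Hugging Face datasets library (the repository ids, column names, and the concatenate_datasets(axis=1) call are assumptions for this sketch, not the script's exact logic):

# Minimal sketch, assuming two row-aligned datasets (same length and order).
from datasets import load_dataset, concatenate_datasets

audio_ds = load_dataset("stable-speech/mls_eng_10k", split="train")   # main audio dataset named in the diff
meta_ds = load_dataset("my-org/mls-eng-10k-prompts", split="train")   # hypothetical metadata repo

# Drop columns that exist on both sides before concatenating column-wise,
# otherwise the merge would produce duplicate column names.
overlap = set(meta_ds.column_names).intersection(audio_ds.column_names)
meta_ds = meta_ds.remove_columns(list(overlap))

merged = concatenate_datasets([audio_ds, meta_ds], axis=1)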
@@ -732,7 +733,7 @@ def load_multiple_datasets(
            )
            # TODO(YL): I forgot to create unique ids for MLS english.
            # To iterate faster, I bypass the original id check and do another one. - Done once
            # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
            # if dataset_dict["name"] == "stable-speech/mls_eng_10k":
            #     def concat_ids(book_id, speaker_id, begin_time):
            #         return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
            ...
@@ -760,7 +761,7 @@ def load_multiple_datasets(
            # We might have applied some transformations to the prompts (e.g punctuation restoration)
            # so we make sure to remove it from the original dataset
            if prompt_column_name in dataset.column_names:
                print(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
                logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
                dataset.remove_columns(prompt_column_name)

            metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
            ...
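A side note on the Dataset.remove_columns call in this hunk: in the Hugging Face datasets API it returns a new dataset rather than mutating the existing one, so the usual pattern reassigns the result. A minimal sketch of that pattern (variable names taken from the hunk, the reassignment is the point):

# Sketch only: remove_columns returns a new Dataset, so capture the return value
# for the prompt column to actually disappear from `dataset`.
if prompt_column_name in dataset.column_names:
    dataset = dataset.remove_columns(prompt_column_name)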
@@ -771,10 +772,6 @@ def load_multiple_datasets(
            if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
                if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
                    raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")

            # TODO: remove
            print("dataset", dataset)
            print(dataset[0][prompt_column_name])

        dataset_features = dataset.features.keys()
        ...
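For reference, the filter(..., input_columns=...) idiom used in this id check passes only the listed columns' values to the callable, so a non-empty result means some rows' ids disagree after the merge. A tiny self-contained example (toy data invented for illustration):

# Toy illustration of the alignment check; column names are made up.
from datasets import Dataset

toy = Dataset.from_dict({"id": ["a", "b", "c"], "metadata_id": ["a", "x", "c"]})

# Keep rows where the two id columns disagree; any surviving row signals a misaligned merge.
mismatched = toy.filter(lambda id1, id2: id1 != id2, input_columns=["id", "metadata_id"])
assert len(mismatched) == 1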
@@ -996,6 +993,9 @@ def main():
            num_proc=data_args.preprocessing_num_workers,
            id_column_name=data_args.id_column_name,
            columns_to_keep=columns_to_keep.values(),
            prompt_column_name=data_args.prompt_column_name,
            audio_column_name=data_args.target_audio_column_name,
            sampling_rate=sampling_rate,
            # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
        )
        ...
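Tying this back to the commit message: the prompt text can now come from a column of a separate metadata dataset rather than from the data itself. A hypothetical sketch of the relevant inputs, where only the dict keys and argument names are taken from the hunks above and every concrete value is invented:

# Hypothetical example values; only the keys/argument names appear in the code above.
dataset_dict = {
    "name": "stable-speech/mls_eng_10k",                     # main audio dataset named in the diff
    "config": "default",                                      # assumed config name
    "split": "train",
    "metadata_dataset_name": "my-org/mls-eng-10k-prompts",    # hypothetical metadata repo
}

# With a metadata dataset attached, prompt_column_name refers to a column of the
# merged metadata (e.g. a text description) instead of a column of the raw data.
prompt_column_name = "text_description"                       # assumed column name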