Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
cb109592
Commit
cb109592
authored
Apr 04, 2024
by
Yoach Lacombe
Browse files
add prompt column name handling
parent
9ef35aa6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
0 deletions
+15
-0
run_stable_speech_training.py
run_stable_speech_training.py
+15
-0
No files found.
run_stable_speech_training.py
View file @
cb109592
...
@@ -687,6 +687,7 @@ def load_multiple_datasets(
...
@@ -687,6 +687,7 @@ def load_multiple_datasets(
seed
:
Optional
[
int
]
=
None
,
seed
:
Optional
[
int
]
=
None
,
id_column_name
:
Optional
[
str
]
=
None
,
id_column_name
:
Optional
[
str
]
=
None
,
columns_to_keep
:
Optional
[
Set
[
str
]]
=
None
,
columns_to_keep
:
Optional
[
Set
[
str
]]
=
None
,
prompt_column_name
:
Optional
[
str
]
=
None
,
sampling_rate
:
Optional
[
int
]
=
None
,
sampling_rate
:
Optional
[
int
]
=
None
,
audio_column_name
:
Optional
[
str
]
=
None
,
audio_column_name
:
Optional
[
str
]
=
None
,
**
kwargs
,
**
kwargs
,
...
@@ -753,15 +754,28 @@ def load_multiple_datasets(
...
@@ -753,15 +754,28 @@ def load_multiple_datasets(
elif
id_column_name
is
not
None
:
elif
id_column_name
is
not
None
:
metadata_dataset
=
metadata_dataset
.
rename_column
(
id_column_name
,
f
"metadata_
{
id_column_name
}
"
)
metadata_dataset
=
metadata_dataset
.
rename_column
(
id_column_name
,
f
"metadata_
{
id_column_name
}
"
)
metadata_columns_to_remove
=
set
(
metadata_dataset
.
column_names
).
intersection
(
set
(
dataset
.
column_names
))
if
prompt_column_name
is
not
None
:
# We might have applied some transformations to the prompts (e.g punctuation restoration)
# so we make sure to remove it from the original dataset
if
prompt_column_name
in
dataset
.
column_names
:
print
(
f
"REMOVE
{
prompt_column_name
}
from dataset
{
dataset_dict
[
'name'
]
}
- dataset_dict['split']"
)
dataset
.
remove_columns
(
prompt_column_name
)
metadata_columns_to_remove
=
set
(
metadata_dataset
.
column_names
).
intersection
(
set
(
dataset
.
column_names
))
metadata_columns_to_remove
=
set
(
metadata_dataset
.
column_names
).
intersection
(
set
(
dataset
.
column_names
))
metadata_dataset
=
metadata_dataset
.
remove_columns
(
metadata_columns_to_remove
)
metadata_dataset
=
metadata_dataset
.
remove_columns
(
metadata_columns_to_remove
)
dataset
=
concatenate_datasets
([
dataset
,
metadata_dataset
],
axis
=
1
)
dataset
=
concatenate_datasets
([
dataset
,
metadata_dataset
],
axis
=
1
)
if
id_column_name
is
not
None
and
dataset_dict
[
"name"
]
!=
"stable-speech/mls_eng_10k"
:
if
id_column_name
is
not
None
and
dataset_dict
[
"name"
]
!=
"stable-speech/mls_eng_10k"
:
if
len
(
dataset
.
filter
(
lambda
id1
,
id2
:
id1
!=
id2
,
input_columns
=
[
id_column_name
,
f
"metadata_
{
id_column_name
}
"
]))
!=
0
:
if
len
(
dataset
.
filter
(
lambda
id1
,
id2
:
id1
!=
id2
,
input_columns
=
[
id_column_name
,
f
"metadata_
{
id_column_name
}
"
]))
!=
0
:
raise
ValueError
(
f
"Concatenate didn't work. Some ids don't correspond on dataset
{
dataset_dict
[
'name'
]
}
"
)
raise
ValueError
(
f
"Concatenate didn't work. Some ids don't correspond on dataset
{
dataset_dict
[
'name'
]
}
"
)
# TODO: remove
print
(
"dataset"
,
dataset
)
print
(
dataset
[
0
][
prompt_column_name
])
dataset_features
=
dataset
.
features
.
keys
()
dataset_features
=
dataset
.
features
.
keys
()
if
columns_to_keep
is
not
None
:
if
columns_to_keep
is
not
None
:
...
@@ -954,6 +968,7 @@ def main():
...
@@ -954,6 +968,7 @@ def main():
num_proc
=
data_args
.
preprocessing_num_workers
,
num_proc
=
data_args
.
preprocessing_num_workers
,
id_column_name
=
data_args
.
id_column_name
,
id_column_name
=
data_args
.
id_column_name
,
columns_to_keep
=
columns_to_keep
.
values
(),
columns_to_keep
=
columns_to_keep
.
values
(),
prompt_column_name
=
data_args
.
prompt_column_name
,
audio_column_name
=
data_args
.
target_audio_column_name
,
audio_column_name
=
data_args
.
target_audio_column_name
,
sampling_rate
=
sampling_rate
,
sampling_rate
=
sampling_rate
,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment