Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
dbb95132
Commit
dbb95132
authored
Feb 27, 2024
by
Yoach Lacombe
Browse files
improve mapping
parent
d0140745
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
run_stable_speech_training.py
run_stable_speech_training.py
+7
-4
No files found.
run_stable_speech_training.py
View file @
dbb95132
...
@@ -1005,13 +1005,13 @@ def main():
...
@@ -1005,13 +1005,13 @@ def main():
eos_labels
=
torch
.
ones
((
1
,
num_codebooks
,
1
))
*
audio_encoder_eos_token_id
eos_labels
=
torch
.
ones
((
1
,
num_codebooks
,
1
))
*
audio_encoder_eos_token_id
bos_labels
=
torch
.
ones
((
1
,
num_codebooks
,
1
))
*
audio_encoder_bos_token_id
bos_labels
=
torch
.
ones
((
1
,
num_codebooks
,
1
))
*
audio_encoder_bos_token_id
def
postprocess_dataset
(
sample
,
idx
):
def
postprocess_dataset
(
input_ids
,
prompt_input_ids
,
idx
):
# (1, codebooks, seq_len)
# (1, codebooks, seq_len)
labels
=
all_generated_labels
[
idx
].
transpose
(
0
,
1
).
unsqueeze
(
0
)
labels
=
all_generated_labels
[
idx
].
transpose
(
0
,
1
).
unsqueeze
(
0
)
len_
=
int
(
all_ratios
[
idx
]
*
all_lens
[
idx
])
len_
=
int
(
all_ratios
[
idx
]
*
all_lens
[
idx
])
labels
=
labels
[:,
:,
:
len_
]
labels
=
labels
[:,
:,
:
len_
]
labels
=
labels
[:,
:,
:(
len_
)
%
10
+
20
]
# TODO: change
#
labels = labels[:, :, :(len_)%10+20] # TODO: change
# add bos
# add bos
labels
=
torch
.
cat
([
bos_labels
,
labels
],
dim
=-
1
)
labels
=
torch
.
cat
([
bos_labels
,
labels
],
dim
=-
1
)
...
@@ -1034,14 +1034,17 @@ def main():
...
@@ -1034,14 +1034,17 @@ def main():
# the first timestamp is associated to a row full of BOS, let's get rid of it
# the first timestamp is associated to a row full of BOS, let's get rid of it
# we also remove the last timestampts (full of PAD)
# we also remove the last timestampts (full of PAD)
sample
[
"labels"
]
=
labels
[:,
1
:].
cpu
()
output
=
{
"labels"
:
labels
[:,
1
:].
cpu
()}
return
sample
output
[
"input_ids"
]
=
input_ids
output
[
"prompt_input_ids"
]
=
prompt_input_ids
return
output
# TODO: done multiple times, how to deal with it.
# TODO: done multiple times, how to deal with it.
with
accelerator
.
main_process_first
():
with
accelerator
.
main_process_first
():
vectorized_datasets
[
split
]
=
vectorized_datasets
[
split
].
map
(
vectorized_datasets
[
split
]
=
vectorized_datasets
[
split
].
map
(
postprocess_dataset
,
postprocess_dataset
,
num_proc
=
num_workers
,
num_proc
=
num_workers
,
input_columns
=
[
"input_ids"
,
"prompt_input_ids"
],
desc
=
"Postprocessing labeling"
,
desc
=
"Postprocessing labeling"
,
with_indices
=
True
,
with_indices
=
True
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment