Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
84e0def5
"vscode:/vscode.git/clone" did not exist on "58631803e5ab484a0a083ba43d5f5507b0d70c4f"
Commit
84e0def5
authored
Mar 26, 2024
by
Yoach Lacombe
Browse files
fix encodec collator
parent
b09eba24
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
10 deletions
+5
-10
run_stable_speech_training.py
run_stable_speech_training.py
+5
-10
No files found.
run_stable_speech_training.py
View file @
84e0def5
...
@@ -495,13 +495,10 @@ class DataCollatorEncodecWithPadding:

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        audios = [torch.tensor(feature[self.audio_column_name]).squeeze() for feature in features]
+        audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]

-        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest", return_attention_mask=True)
-        batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1)  # add mono-channel
-        batch["padding_mask"] = batch.pop("attention_mask")
+        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
         batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
         return batch

(Reconstructed from the garbled side-by-side rendering: context lines appeared once per pane; single-occurrence lines are the changed ones. Exact blank-line changes are not recoverable from the extraction and may account for the remaining delta in the hunk's line counts.)
...
@@ -1083,17 +1080,15 @@ def main():

         # this is a trick to avoid to rewrite the entire audio column which takes ages
         cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
         for split in raw_datasets:
-            vectorized_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
+            raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)

-        # TODO: remove
-        logger.info(f"Vectorized datasets {vectorized_datasets}")

     with accelerator.main_process_first():
         def is_audio_in_length_range(length):
             return length > min_target_length and length < max_target_length

         # filter data that is shorter than min_target_length
-        vectorized_datasets = vectorized_datasets.filter(
+        vectorized_datasets = raw_datasets.filter(
             is_audio_in_length_range,
             num_proc=num_workers,
             input_columns=["target_length"],

(Reconstructed from the garbled side-by-side rendering: duplicated lines are unchanged context; single-occurrence lines are one-sided. The deleted/added assignment ensures the concatenated dataset lands in raw_datasets, which the later .filter(...) call now reads from. Exact surrounding indentation and blank-line changes are approximate, as the extraction dropped layout.)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment