Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
9232a47b
Unverified
Commit
9232a47b
authored
May 22, 2024
by
Yoach Lacombe
Committed by
GitHub
May 22, 2024
Browse files
Merge pull request #53 from ylacombe/nits-improvements
[Training] Small nits
parents
5518cc2f
a0bc9e78
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
6 deletions
+22
-6
training/arguments.py
training/arguments.py
+7
-1
training/data.py
training/data.py
+1
-1
training/run_parler_tts_training.py
training/run_parler_tts_training.py
+14
-4
No files found.
training/arguments.py
View file @
9232a47b
...
...
@@ -218,7 +218,7 @@ class DataTrainingArguments:
metadata
=
{
"help"
:
(
"If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
"Also, used to set maximum desription token length if `pad_to_max_length=True`."
"Also, used to set maximum des
c
ription token length if `pad_to_max_length=True`."
)
},
)
...
...
@@ -277,6 +277,12 @@ class DataTrainingArguments:
default
=
"parler-speech"
,
metadata
=
{
"help"
:
"The name of the wandb project."
},
)
wandb_run_name
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"If specified, the name of the run. If not specified, wandb will give a random name to this run."
},
)
save_to_disk
:
str
=
field
(
default
=
None
,
metadata
=
{
...
...
training/data.py
View file @
9232a47b
...
...
@@ -31,7 +31,7 @@ class DataCollatorEncodecWithPadding:
audios
=
[
feature
[
self
.
audio_column_name
][
"array"
]
for
feature
in
features
]
len_audio
=
[
len
(
audio
)
for
audio
in
audios
]
# since resampling has already been performed in the 'load_multiple_datasets' function,
# since resampling has already been performed in the 'load_multiple_datasets' function,
# a fixed sampling_rate(44100hz) is passed to the feature_extractor.
sampling_rate
=
self
.
feature_extractor
.
sampling_rate
batch
=
self
.
feature_extractor
(
...
...
training/run_parler_tts_training.py
View file @
9232a47b
...
...
@@ -98,9 +98,6 @@ def main():
####### A. Preparation
kwargs_handlers
=
[
InitProcessGroupKwargs
(
timeout
=
timedelta
(
minutes
=
60
))]
if
training_args
.
torch_compile
:
# TODO(YL): add more compile modes?
kwargs_handlers
.
append
(
TorchDynamoPlugin
(
backend
=
"inductor"
,
mode
=
"default"
))
# reduce-overhead
accelerator
=
Accelerator
(
gradient_accumulation_steps
=
training_args
.
gradient_accumulation_steps
,
...
...
@@ -129,6 +126,7 @@ def main():
"adam_beta2"
:
training_args
.
adam_beta2
,
"temperature"
:
model_args
.
temperature
,
},
init_kwargs
=
{
"wandb"
:
{
"name"
:
data_args
.
wandb_run_name
}}
if
data_args
.
wandb_run_name
else
None
,
)
# Detecting last checkpoint and eventually continue from last checkpoint
...
...
@@ -538,7 +536,7 @@ def main():
logger
.
info
(
f
"Dataset saved at
{
data_args
.
save_to_disk
}
"
)
audio_max_length
=
None
if
training_args
.
torch_compile
:
if
padding
==
"max_length"
:
audio_max_length
=
max
(
vectorized_datasets
[
"train"
][
"target_length"
])
with
accelerator
.
main_process_first
():
max_sample
=
vectorized_datasets
[
"train"
].
filter
(
...
...
@@ -548,6 +546,18 @@ def main():
)
audio_max_length
=
torch
.
tensor
(
max_sample
[
0
][
"labels"
]).
shape
[
1
]
if
training_args
.
group_by_length
:
# apply a simple heuristic to take into account audio and text lengths
def
add_target_lengths
(
target_length
,
prompt
,
description
):
return
{
"target_length"
:
target_length
+
len
(
prompt
)
+
len
(
description
)}
with
accelerator
.
main_process_first
():
vectorized_datasets
=
vectorized_datasets
.
map
(
add_target_lengths
,
num_proc
=
num_workers
,
input_columns
=
[
"target_length"
,
"prompt_input_ids"
,
"input_ids"
],
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will mostly likely
# be a timeout when running the script in distributed mode.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment