Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
a04d6858
Unverified
Commit
a04d6858
authored
Apr 12, 2024
by
Wauplin
Browse files
Don't use deprecated Repository anymore
parent
10016fb0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
31 deletions
+24
-31
training/run_parler_tts_training.py
training/run_parler_tts_training.py
+24
-31
No files found.
training/run_parler_tts_training.py
View file @
a04d6858
...
@@ -19,28 +19,27 @@
...
@@ -19,28 +19,27 @@
import
logging
import
logging
import
os
import
os
import
re
import
re
import
sys
import
shutil
import
shutil
import
sys
import
time
import
time
from
multiprocess
import
set_start_metho
d
from
dataclasses
import
dataclass
,
fiel
d
from
datetime
import
timedelta
from
datetime
import
timedelta
import
evaluate
from
tqdm
import
tqdm
from
pathlib
import
Path
from
pathlib
import
Path
from
dataclasses
import
dataclass
,
field
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
,
Set
import
datasets
import
datasets
import
evaluate
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
torch.utils.data
import
DataLoader
from
datasets
import
DatasetDict
,
load_dataset
,
Dataset
,
IterableDataset
,
interleave_datasets
,
concatenate_datasets
from
huggingface_hub
import
Repository
,
create_repo
import
transformers
import
transformers
from
accelerate
import
Accelerator
from
accelerate.utils
import
AutocastKwargs
,
InitProcessGroupKwargs
,
TorchDynamoPlugin
,
set_seed
from
accelerate.utils.memory
import
release_memory
from
datasets
import
Dataset
,
DatasetDict
,
IterableDataset
,
concatenate_datasets
,
interleave_datasets
,
load_dataset
from
huggingface_hub
import
HfApi
from
multiprocess
import
set_start_method
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
transformers
import
(
from
transformers
import
(
AutoFeatureExtractor
,
AutoFeatureExtractor
,
AutoModel
,
AutoModel
,
...
@@ -48,26 +47,19 @@ from transformers import (
...
@@ -48,26 +47,19 @@ from transformers import (
AutoTokenizer
,
AutoTokenizer
,
HfArgumentParser
,
HfArgumentParser
,
Seq2SeqTrainingArguments
,
Seq2SeqTrainingArguments
,
pipeline
,
)
)
from
transformers.trainer_pt_utils
import
LengthGroupedSampler
from
transformers
import
pipeline
from
transformers.optimization
import
get_scheduler
from
transformers.optimization
import
get_scheduler
from
transformers.trainer_pt_utils
import
LengthGroupedSampler
from
transformers.utils
import
send_example_telemetry
from
transformers.utils
import
send_example_telemetry
from
transformers
import
AutoModel
from
wandb
import
Audio
from
accelerate
import
Accelerator
from
accelerate.utils
import
set_seed
,
AutocastKwargs
,
InitProcessGroupKwargs
,
TorchDynamoPlugin
from
accelerate.utils.memory
import
release_memory
from
parler_tts
import
(
from
parler_tts
import
(
ParlerTTSForConditionalGeneration
,
ParlerTTSConfig
,
ParlerTTSConfig
,
ParlerTTSForConditionalGeneration
,
build_delay_pattern_mask
,
build_delay_pattern_mask
,
)
)
from
wandb
import
Audio
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -1415,14 +1407,13 @@ def main():
...
@@ -1415,14 +1407,13 @@ def main():
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
if
training_args
.
push_to_hub
:
if
training_args
.
push_to_hub
:
# Retrieve of infer repo_name
api
=
HfApi
(
token
=
training_args
.
hub_token
)
# Create repo (repo_name from args or inferred)
repo_name
=
training_args
.
hub_model_id
repo_name
=
training_args
.
hub_model_id
if
repo_name
is
None
:
if
repo_name
is
None
:
repo_name
=
Path
(
training_args
.
output_dir
).
absolute
().
name
repo_name
=
Path
(
training_args
.
output_dir
).
absolute
().
name
# Create repo and retrieve repo_id
repo_id
=
api
.
create_repo
(
repo_name
,
exist_ok
=
True
).
repo_id
repo_id
=
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
training_args
.
hub_token
).
repo_id
# Clone repo locally
repo
=
Repository
(
training_args
.
output_dir
,
clone_from
=
repo_id
,
token
=
training_args
.
hub_token
)
with
open
(
os
.
path
.
join
(
training_args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
training_args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"wandb"
not
in
gitignore
:
if
"wandb"
not
in
gitignore
:
...
@@ -1624,9 +1615,11 @@ def main():
...
@@ -1624,9 +1615,11 @@ def main():
unwrapped_model
.
save_pretrained
(
training_args
.
output_dir
)
unwrapped_model
.
save_pretrained
(
training_args
.
output_dir
)
if
training_args
.
push_to_hub
:
if
training_args
.
push_to_hub
:
repo
.
push_to_hub
(
api
.
upload_folder
(
repo_id
=
repo_id
,
folder_path
=
training_args
.
output_dir
,
commit_message
=
f
"Saving train state of step
{
cur_step
}
"
,
commit_message
=
f
"Saving train state of step
{
cur_step
}
"
,
blocking
=
Fals
e
,
run_as_future
=
Tru
e
,
)
)
if
training_args
.
do_eval
and
(
cur_step
%
eval_steps
==
0
or
cur_step
==
total_train_steps
):
if
training_args
.
do_eval
and
(
cur_step
%
eval_steps
==
0
or
cur_step
==
total_train_steps
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment