Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
a7231794
Commit
a7231794
authored
Mar 24, 2024
by
sanchit-gandhi
Browse files
make post-processing more efficient
parent
1176f1bb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
11 deletions
+13
-11
run_prompt_creation.py
run_prompt_creation.py
+13
-11
No files found.
run_prompt_creation.py
View file @
a7231794
...
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union
...
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Union
import
torch
import
torch
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
accelerate.logging
import
get_logger
from
datasets
import
DatasetDict
,
load_dataset
from
datasets
import
DatasetDict
,
load_dataset
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -18,7 +19,7 @@ from transformers import (
...
@@ -18,7 +19,7 @@ from transformers import (
)
)
logger
=
logging
.
get
L
ogger
(
__name__
)
logger
=
get
_l
ogger
(
__name__
,
log_level
=
"INFO"
)
@
dataclass
@
dataclass
...
@@ -223,6 +224,7 @@ For example, given the following keywords: 'female', 'slightly roomy sounding',
...
@@ -223,6 +224,7 @@ For example, given the following keywords: 'female', 'slightly roomy sounding',
For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:"
For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:"
"""
"""
def
main
():
def
main
():
# 1. Parse input arguments
# 1. Parse input arguments
parser
=
HfArgumentParser
((
ModelArguments
,
DataArguments
))
parser
=
HfArgumentParser
((
ModelArguments
,
DataArguments
))
...
@@ -235,7 +237,6 @@ def main():
...
@@ -235,7 +237,6 @@ def main():
# 2. Setup logging
# 2. Setup logging
# Make one log on every process with the configuration for debugging.
# Make one log on every process with the configuration for debugging.
logger
.
setLevel
(
logging
.
INFO
)
logging
.
basicConfig
(
logging
.
basicConfig
(
format
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
format
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
datefmt
=
"%m/%d/%Y %H:%M:%S"
,
datefmt
=
"%m/%d/%Y %H:%M:%S"
,
...
@@ -368,6 +369,12 @@ def main():
...
@@ -368,6 +369,12 @@ def main():
output_ids
=
accelerator
.
pad_across_processes
(
output_ids
,
dim
=
1
,
pad_index
=
tokenizer
.
pad_token_id
)
output_ids
=
accelerator
.
pad_across_processes
(
output_ids
,
dim
=
1
,
pad_index
=
tokenizer
.
pad_token_id
)
return
output_ids
return
output_ids
def
postprocess_dataset
(
sample
):
prompt_text
=
tokenizer
.
decode
(
sample
[
"input_ids"
],
skip_special_tokens
=
True
)
generated_text
=
tokenizer
.
decode
(
sample
[
"generated_ids"
],
skip_special_tokens
=
True
)
sample
[
"text_description"
]
=
generated_text
[
len
(
prompt_text
)
:]
return
sample
for
split
in
vectorized_datasets
:
for
split
in
vectorized_datasets
:
data_loader
=
DataLoader
(
data_loader
=
DataLoader
(
vectorized_datasets
[
split
],
vectorized_datasets
[
split
],
...
@@ -382,21 +389,16 @@ def main():
...
@@ -382,21 +389,16 @@ def main():
for
batch
in
tqdm
(
data_loader
,
disable
=
not
accelerator
.
is_local_main_process
):
for
batch
in
tqdm
(
data_loader
,
disable
=
not
accelerator
.
is_local_main_process
):
generated_ids
=
generate_step
(
batch
)
generated_ids
=
generate_step
(
batch
)
generated_ids
=
accelerator
.
gather_for_metrics
(
generated_ids
)
generated_ids
=
accelerator
.
gather_for_metrics
(
generated_ids
)
all_generated_ids
.
extend
(
generated_ids
.
cpu
())
all_generated_ids
.
extend
(
generated_ids
.
cpu
()
.
numpy
()
)
def
postprocess_dataset
(
sample
,
idx
):
vectorized_datasets
[
split
]
=
vectorized_datasets
[
split
].
add_column
(
"generated_ids"
,
all_generated_ids
)
prompt_text
=
tokenizer
.
decode
(
sample
[
"input_ids"
],
skip_special_tokens
=
True
)
generated_text
=
tokenizer
.
decode
(
all_generated_ids
[
idx
],
skip_special_tokens
=
True
)
sample
[
"text_description"
]
=
generated_text
[
len
(
prompt_text
)
:]
return
sample
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
vectorized_datasets
[
split
]
=
vectorized_datasets
[
split
].
map
(
vectorized_datasets
[
split
]
=
vectorized_datasets
[
split
].
map
(
postprocess_dataset
,
postprocess_dataset
,
num_proc
=
data_args
.
preprocessing_num_workers
,
num_proc
=
data_args
.
preprocessing_num_workers
,
desc
=
"Postprocessing dataset"
,
desc
=
"Postprocessing dataset"
,
remove_columns
=
[
"input_ids"
],
remove_columns
=
[
"input_ids"
,
"generated_ids"
],
with_indices
=
True
,
)
)
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment