chenpangpang / parler-tts · Commits

Commit 82cbc3ad, authored Apr 05, 2024 by Yoach Lacombe

    add torch compile compatibility + remove precompute_text_hidden_states

Parent: 1fe3fc1e

Showing 1 changed file, run_stable_speech_training.py, with 125 additions and 120 deletions.
@@ -70,7 +70,7 @@ AutoModel.register(DACConfig, DACModel)
 from accelerate import Accelerator
-from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs
+from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory
 from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
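The only import change is `TorchDynamoPlugin`, which is how `accelerate` exposes `torch.compile`. A minimal sketch of the behaviour this commit targets (illustrative, not part of the diff): compiled graphs are specialized to the input shapes they first see, which is why the rest of the diff pads everything to fixed lengths.

```python
import torch

# torch.compile specializes the captured graph to the shapes it first sees;
# re-using one shape reuses the graph, varying shapes triggers recompilation.
model = torch.nn.Linear(16, 4)
compiled = torch.compile(model, backend="inductor")

y1 = compiled(torch.randn(8, 16))  # compiles for shape (8, 16)
y2 = compiled(torch.randn(8, 16))  # same shape: reuses the compiled graph
print(y1.shape, y2.shape)
```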
@@ -247,17 +247,13 @@ class ModelArguments:
         metadata={"help": "Temperature if sampling."},
     )
     max_length: int = field(
-        default=1500,  # TODO
-        metadata={"help": "Whether to do sampling or greedy decoding."},
+        default=2580,
+        metadata={"help": "Generation max length."},
     )
     bandwidth: float = field(
         default=6,  # TODO
         metadata={"help": "Audio encoder bandwidth."},
     )
-    precompute_text_hidden_states: bool = field(
-        default=False,
-        metadata={"help": "Whether to precompute text hidden states. Only work when the text encoder is freezed"},
-    )
@@ -374,8 +370,8 @@ class DataTrainingArguments:
         default=35.0,
         metadata={
             "help": (
-                "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
-                " 'max_duration_in_seconds`"
+                "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`. "
+                "Also, used to set maximum audio length if `pad_to_max_length=True`."
             )
         },
     )
@@ -383,7 +379,31 @@ class DataTrainingArguments:
         default=0.0,
         metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
     )
     max_text_length: int = field(
-        default=500, metadata={"help": "Max description lengths in number of characters."}
+        default=500, metadata={"help": "If set, max description lengths in number of characters."}
     )
+    max_prompt_token_length: int = field(
+        default=None,
+        metadata={
+            "help": (
+                "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
+                "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
+            )
+        },
+    )
+    max_description_token_length: int = field(
+        default=None,
+        metadata={
+            "help": (
+                "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
+                "Also, used to set maximum desription token length if `pad_to_max_length=True`."
+            )
+        },
+    )
+    pad_to_max_length: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "If `True`, pad audio, prompt and description to a maximum length set with respectively "
+                "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
+            )
+        },
+    )
     preprocessing_only: bool = field(
         default=False,
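The three new `DataTrainingArguments` fields above gate the static-shape path. A hedged sketch of how they surface as CLI flags through `HfArgumentParser` (a trimmed stand-in dataclass; values illustrative):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class MiniDataArgs:
    # trimmed stand-ins for the fields added in this commit
    max_prompt_token_length: Optional[int] = field(default=None)
    max_description_token_length: Optional[int] = field(default=None)
    pad_to_max_length: bool = field(default=False)

(args,) = HfArgumentParser(MiniDataArgs).parse_args_into_dataclasses(
    ["--max_prompt_token_length", "64",
     "--max_description_token_length", "256",
     "--pad_to_max_length", "true"]
)
assert args.pad_to_max_length and args.max_prompt_token_length == 64
```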
@@ -490,13 +510,15 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
-    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
+    Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
+    to `max_length` if `max_length` is set and `padding=max_length`.
     """

     feature_extractor: AutoFeatureExtractor
     audio_column_name: str
     feature_extractor_input_name: Optional[str] = "input_values"
     max_length: Optional[int] = None
+    padding: Optional[str] = "longest"

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
@@ -504,10 +526,8 @@ class DataCollatorEncodecWithPadding:
         # different padding methods
         audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]
-        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
+        if self.max_length is not None:
+            audios = [audio[: min(len(audio), self.max_length + 10)] for audio in audios]
+        batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length)
         batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
         return batch
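With the new `padding` field, the collator can hand `padding="max_length"` straight to the feature extractor so every batch comes out the same size. A hedged usage sketch (the EnCodec checkpoint is illustrative; keys and shapes follow the usual `transformers` feature-extractor convention):

```python
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_24khz")

# Two clips of different lengths both come back padded to 2000 samples,
# giving the static shapes that torch.compile needs.
batch = feature_extractor(
    [[0.0] * 1000, [0.0] * 1500],
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="pt",
    padding="max_length",
    max_length=2000,
)
print(batch["input_values"].shape)  # e.g. torch.Size([2, 1, 2000])
```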
@@ -544,6 +564,9 @@ class DataCollatorStableSpeechWithPadding:
     feature_extractor_input_name: Optional[str] = "input_values"
     padding: Union[bool, str] = "longest"
     pad_to_multiple_of: Optional[int] = None
+    prompt_max_length: Optional[int] = None
+    description_max_length: Optional[int] = None
+    audio_max_length: Optional[int] = None

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
@@ -553,28 +576,29 @@ class DataCollatorStableSpeechWithPadding:
         labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
         # (bsz, seq_len, num_codebooks)
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
+        if self.audio_max_length is not None and self.padding == "max_length":
+            labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))

         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
-        if "encoder_outputs" in features[0]:
-            input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding)
-            encoder_hidden_states = [torch.tensor(feature["encoder_outputs"]["last_hidden_state"]) for feature in features]
-            encoder_hidden_states = torch.nn.utils.rnn.pad_sequence(encoder_hidden_states, batch_first=True, padding_value=0.)
-            batch = {"labels": labels, "encoder_outputs": BaseModelOutput(last_hidden_state=encoder_hidden_states), **input_ids}
-        else:
-            input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
-            batch = {"labels": labels, **input_ids}
+        input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.description_max_length)
+
+        batch = {"labels": labels, **input_ids}
+
+        if self.audio_max_length is not None and self.padding == "max_length":
+            # if we do torch.compile, we need to also specify the attention_mask
+            decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
+            batch["decoder_attention_mask"] = decoder_attention_mask

         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
-        prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+        prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.prompt_max_length)

         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

         if self.feature_extractor_input_name in features[0]:
-            # TODO: verify that it works
+            # TODO(YL): verify that it works - IMPORTANT -> probably not working
             input_values = [{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features]
             input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
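The two `audio_max_length` branches above are the heart of the compile compatibility: labels are padded up to one fixed length and an explicit all-ones `decoder_attention_mask` is emitted. A standalone sketch of that arithmetic (toy shapes):

```python
import torch

audio_max_length = 10
labels = torch.full((2, 7, 9), 3)  # (bsz, seq_len, num_codebooks)

# pad=(0, 0, 0, n) leaves the codebook dim alone and appends n steps to seq_len
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(audio_max_length - labels.shape[1], 0)))
decoder_attention_mask = torch.ones(labels.shape[:2], dtype=torch.long)

print(labels.shape)                  # torch.Size([2, 10, 9])
print(decoder_attention_mask.shape)  # torch.Size([2, 10])
```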
@@ -582,40 +606,6 @@ class DataCollatorStableSpeechWithPadding:
         return batch


-@dataclass
-class T5TextCollatorStableSpeechWithPadding:
-    """
-    Data collator that will dynamically pad the inputs received.
-    Args:
-        description_tokenizer (:class:`~transformers.AutoTokenizer`)
-            The description_tokenizer used for proccessing the data.
-        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
-            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
-            among:
-            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-              sequence if provided).
-            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
-              maximum acceptable input length for the model if that argument is not provided.
-            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
-              different lengths).
-        pad_to_multiple_of (:obj:`int`, `optional`):
-            If set will pad the sequence to a multiple of the provided value.
-            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
-            7.5 (Volta).
-    """
-
-    description_tokenizer: AutoTokenizer
-    padding: Union[bool, str] = "longest"
-    pad_to_multiple_of: Optional[int] = None
-
-    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
-        input_ids_len = [len(feature["input_ids"]) for feature in features]
-        input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
-        batch = {"len_input_ids": torch.tensor(input_ids_len), **input_ids}
-        return batch
-
-
 def convert_dataset_str_to_list(
     dataset_names,
     dataset_config_names,
@@ -821,15 +811,24 @@ def main():
         mixed_precision = "bf16"
     else:
         mixed_precision = "no"

+    if data_args.pad_to_max_length and (
+        data_args.max_duration_in_seconds is None
+        or data_args.max_prompt_token_length is None
+        or data_args.max_description_token_length is None
+    ):
+        raise ValueError(
+            "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
+        )
+
+    padding = "max_length" if data_args.pad_to_max_length else "longest"

     ####### A. Preparation
+    kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
+    if training_args.torch_compile:
+        # TODO(YL): add more compile modes?
+        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor"))
+
     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         mixed_precision=mixed_precision,
         log_with=training_args.report_to,
         project_dir=training_args.output_dir,
-        kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(minutes=60))],
+        kwargs_handlers=kwargs_handlers,
     )

     accelerator.init_trackers(
         project_name=data_args.wandb_project,
         config={
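A minimal sketch of the `Accelerator` wiring above, mirroring this commit's approach of appending `TorchDynamoPlugin` to `kwargs_handlers` only when compilation is requested (whether `kwargs_handlers` accepts the plugin depends on the installed `accelerate` version; values illustrative):

```python
from datetime import timedelta

from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs, TorchDynamoPlugin

torch_compile = True  # stands in for training_args.torch_compile

kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
if torch_compile:
    # non-compiled runs keep the exact same Accelerator call, just without the plugin
    kwargs_handlers.append(TorchDynamoPlugin(backend="inductor"))

accelerator = Accelerator(gradient_accumulation_steps=1, kwargs_handlers=kwargs_handlers)
print(accelerator.device)
```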
@@ -1064,7 +1063,7 @@ def main():
     if not dataset_was_precomputed:
         # Filter on text length
-        if description_column_name is not None:
+        if description_column_name is not None and data_args.max_text_length is not None:
             with accelerator.main_process_first():
                 # filter description that is shorter than max_text_length
                 raw_datasets = raw_datasets.filter(
@@ -1098,64 +1097,19 @@ def main():
     # T5 doesn't support fp16
     autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))

-    ####### B. (Optional) Encode text if text encoder is freezed
-    if model_args.freeze_text_encoder and model_args.precompute_text_hidden_states:
-        text_data_collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of)
-        for split in vectorized_datasets:
-            data_loader = DataLoader(
-                vectorized_datasets[split],
-                batch_size=training_args.text_encode_per_device_eval_batch_size,
-                collate_fn=text_data_collator,
-                num_workers=training_args.dataloader_num_workers,
-                pin_memory=True,
-            )
-            data_loader = accelerator.prepare(data_loader)
-
-            all_encoder_outputs = []
-            all_encoder_lengths = []
-            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
-                model.text_encoder.to(batch["input_ids"].device)
-                with accelerator.autocast(autocast_handler=autocast_kwargs):
-                    with torch.no_grad():
-                        encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
-                encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
-                encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
-                lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
-                if accelerator.is_main_process:
-                    all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
-                    all_encoder_lengths.extend(lengths.to("cpu"))
-
-            def postprocess_dataset(input_ids, idx):
-                output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][: all_encoder_lengths[idx]])}
-                return output
-
-            # TODO(YL): done multiple times, how to deal with it.
-            with accelerator.main_process_first():
-                vectorized_datasets[split] = vectorized_datasets[split].map(
-                    postprocess_dataset,
-                    num_proc=1,  # this one is resource consuming if many processor.
-                    input_columns=["input_ids"],
-                    desc="Postprocessing labeling",
-                    with_indices=True,
-                    writer_batch_size=100,
-                )
-
-        accelerator.wait_for_everyone()
-        accelerator.free_memory()
-        del data_loader, all_encoder_outputs, all_encoder_lengths
-
     # Now we encode the audio labels with encodec.
-    ####### C. Encode audio
+    ####### B. Encode audio
     logger.info("*** Encode target audio with encodec ***")

     # no need to prepare audio_decoder because used for inference without mixed precision
     # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
-    audio_decoder = model.audio_encoder
+    if training_args.torch_compile:
+        audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
+    else:
+        audio_decoder = model.audio_encoder

-    encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length)
+    encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length, padding=padding)

     def apply_audio_decoder(batch):
         len_audio = batch.pop("len_audio")
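`prepare_model(..., evaluation_mode=True)` above is what routes the frozen audio encoder through `torch.compile` without registering it for training. A hedged sketch with a stand-in module:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# stand-in for model.audio_encoder: inference only, no optimizer involved
audio_encoder = torch.nn.Conv1d(1, 8, kernel_size=3)
audio_decoder = accelerator.prepare_model(audio_encoder, evaluation_mode=True)

with torch.no_grad():
    codes = audio_decoder(torch.randn(2, 1, 100))
print(codes.shape)  # torch.Size([2, 8, 98])
```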
@@ -1251,6 +1205,10 @@ def main():
     with accelerator.main_process_first():
+        # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
+        # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
+        # That's also why we avoid to concat the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
+
         def is_audio_in_length_range(length):
             return length > min_target_length and length < max_target_length
@@ -1260,11 +1218,41 @@ def main():
             num_proc=num_workers,
             input_columns=["target_length"],
         )

+    if description_column_name is not None and data_args.max_description_token_length is not None:
+        with accelerator.main_process_first():
+            # filter description that is shorter than max_text_length
+            vectorized_datasets = vectorized_datasets.filter(
+                lambda x: len(x) < data_args.max_description_token_length,
+                num_proc=num_workers,
+                input_columns=["input_ids"],
+            )
+
+    if data_args.max_prompt_token_length is not None:
+        with accelerator.main_process_first():
+            # filter description that is shorter than max_text_length
+            vectorized_datasets = vectorized_datasets.filter(
+                lambda x: len(x) < data_args.max_prompt_token_length,
+                num_proc=num_workers,
+                input_columns=["prompt_input_ids"],
+            )
+
     if data_args.save_to_disk is not None and not dataset_was_precomputed:
         if accelerator.is_main_process:
             vectorized_datasets.save_to_disk(
                 data_args.save_to_disk,
                 num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
             )
         logger.info(f"Dataset saved at {data_args.save_to_disk}")

+    audio_max_length = None
+    if training_args.torch_compile:
+        audio_max_length = max(vectorized_datasets["train"]["target_length"])
+        with accelerator.main_process_first():
+            max_sample = vectorized_datasets["train"].filter(
+                lambda x: x == audio_max_length,
+                num_proc=num_workers,
+                input_columns=["target_length"],
+            )
+        audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
+
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
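The `audio_max_length` block above first finds the longest raw target, then reads the label length of that very sample, so padding is expressed in codebook frames rather than audio samples. A toy reproduction of the lookup (synthetic dataset; layout assumed from the collator's transpose):

```python
import torch
from datasets import Dataset

# labels stored as (num_codebooks, seq_len), matching what the collator transposes later
train = Dataset.from_dict({
    "target_length": [120, 300, 80],
    "labels": [[[1] * 3] * 9, [[1] * 7] * 9, [[1] * 2] * 9],
})

audio_max_length = max(train["target_length"])
max_sample = train.filter(lambda x: x == audio_max_length, input_columns=["target_length"])
audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
print(audio_max_length)  # 7: the padding target in codebook frames
```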
@@ -1374,7 +1362,8 @@ def main():
     # Instantiate custom data collator
     data_collator = DataCollatorStableSpeechWithPadding(
-        audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of
+        audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of,
+        padding=padding, prompt_max_length=data_args.max_prompt_token_length, description_max_length=data_args.max_description_token_length, audio_max_length=audio_max_length
     )
@@ -1485,7 +1474,7 @@ def main():
     ):
         model.train()

-        if mixed_precision == "fp16" and not (model_args.freeze_text_encoder and model_args.precompute_text_hidden_states):
+        if mixed_precision == "fp16":
             # fp16 doesn't work with T5-like models
             with accelerator.autocast(autocast_handler=autocast_kwargs):
                 if training_args.parallel_mode.value != "distributed":
@@ -1505,7 +1494,7 @@ def main():
     # Define eval fn
     def eval_step(batch, accelerator, autocast_kwargs,):
         model.eval()

-        if mixed_precision == "fp16" and not (model_args.freeze_text_encoder and model_args.precompute_text_hidden_states):
+        if mixed_precision == "fp16":
             # fp16 doesn't work with T5-like models
             with accelerator.autocast(autocast_handler=autocast_kwargs):
                 if training_args.parallel_mode.value != "distributed":
@@ -1523,6 +1512,7 @@ def main():
     def generate_step(batch):
         model.eval()
+        batch.pop("decoder_attention_mask", None)
         output_audios = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").generate(**batch, **gen_kwargs)
         output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
         return output_audios
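`decoder_attention_mask` exists only to give the compiled training forward a static shape; `generate` builds its own decoder inputs from the delay pattern, so the key is dropped first. A dict-level sketch (illustrative batch):

```python
# pop(..., None) is a no-op when the key is absent, so non-compiled runs are unaffected
batch = {"input_ids": [[101, 2023]], "decoder_attention_mask": [[1, 1, 1]]}
batch.pop("decoder_attention_mask", None)
assert "decoder_attention_mask" not in batch
```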
@@ -1626,7 +1616,7 @@ def main():
             for batch in tqdm(
                 validation_dataloader,
-                desc=f"Evaluating...",
+                desc=f"Evaluating - Inference...",
                 position=2,
                 disable=not accelerator.is_local_main_process,
             ):
@@ -1635,8 +1625,23 @@ def main():
                     eval_metric = accelerator.gather_for_metrics(eval_metric)
                     eval_metrics.append(eval_metric)

+            if training_args.predict_with_generate:
+                validation_dataloader = DataLoader(
+                    vectorized_datasets["eval"],
+                    collate_fn=data_collator,
+                    batch_size=per_device_eval_batch_size,
+                    drop_last=False,
+                    num_workers=training_args.dataloader_pin_memory,
+                    pin_memory=training_args.dataloader_pin_memory,
+                )
+                validation_dataloader = accelerator.prepare(validation_dataloader)
                 # generation
-            if training_args.predict_with_generate:
+                for batch in tqdm(
+                    validation_dataloader,
+                    desc=f"Evaluating - Generation ...",
+                    position=2,
+                    disable=not accelerator.is_local_main_process,
+                ):
                     generated_audios = generate_step(batch)

                     # Gather all predictions and targets
                     # TODO: also add prompt ids