chenpangpang / parler-tts / Commits

Commit e7cc576a
authored Mar 04, 2024 by Yoach Lacombe
parent 9bde9933

    add dac config, init, and temporary datasets saving

Showing 8 changed files with 510 additions and 198 deletions (+510 −198)
    example_configs/librispeech_tts_r.json             +1   −1
    example_configs/librispeech_tts_r_100.json         +7   −8
    example_configs/librispeech_tts_r_75M.json         +77  −0
    example_configs/librispeech_tts_r_dummy_dac.json   +76  −0
    init_dummy_model_dac.py                            +67  −0
    init_model.py                                      +9   −3
    init_model_75M.py                                  +67  −0
    run_stable_speech_training.py                      +206 −186
example_configs/librispeech_tts_r.json

@@ -63,7 +63,7 @@
     "evaluation_strategy": "steps",
     "eval_steps": 600,
     "per_device_eval_batch_size": 8,
-    "generation_max_length": 400,
+    "generation_max_length": 2250,
     "fp16": false,
     "seed": 456,
example_configs/librispeech_tts_r_100.json

 {
     "model_name_or_path": "/raid/yoach/tmp/artefacts/small-stable-speech-untrained/",
-    "feature_extractor_name": "facebook/encodec_24khz",
+    "feature_extractor_name": "ylacombe/dac_44khZ_8kbps",
     "description_tokenizer_name": "google-t5/t5-small",
     "prompt_tokenizer_name": "google-t5/t5-small",
-    "push_to_hub": true,
+    "push_to_hub": false,
     "hub_model_id": "ylacombe/stable-speech-mini",
     "report_to": ["wandb"],
-    "overwrite_output_dir": true,
+    "overwrite_output_dir": false,
     "output_dir": "/raid/yoach/tmp/artefacts/training-mini/",

@@ -34,7 +34,7 @@
     "add_audio_samples_to_wandb": true,
     "id_column_name": "id",
-    "preprocessing_num_workers": 1,
+    "preprocessing_num_workers": 8,
     "pad_token_id": 1024,

@@ -45,7 +45,7 @@
     "num_train_epochs": 15,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": true,
-    "per_device_train_batch_size": 40,
+    "per_device_train_batch_size": 28,
     "learning_rate": 1e-4,
     "adam_beta1": 0.9,
     "adam_beta2": 0.999,

@@ -63,11 +63,10 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "steps",
-    "eval_steps": 3000,
-    "save_steps": 3000,
+    "eval_steps": 2500,
+    "save_steps": 2499,
     "per_device_eval_batch_size": 8,
-    "generation_max_length": 400,
     "audio_encode_per_device_eval_batch_size": 32,
     "dtype": "float16",
example_configs/librispeech_tts_r_75M.json (new file, 0 → 100644)

{
    "model_name_or_path": "/raid/yoach/tmp/artefacts/stable-speech-untrained-75M/",
    "save_to_disk": "/raid/yoach/tmp/artefacts/libritts_r_1k_hours_processed/",
    "preprocessing_only": false,
    "feature_extractor_name": "ylacombe/dac_44khZ_8kbps",
    "description_tokenizer_name": "google/t5-v1_1-small",
    "prompt_tokenizer_name": "google/t5-v1_1-small",
    "push_to_hub": false,
    "hub_model_id": "ylacombe/stable-speech-75M",
    "report_to": ["wandb"],
    "overwrite_output_dir": false,
    "output_dir": "/raid/yoach/tmp/artefacts/training-75M-0.1/",
    "train_dataset_name": "blabble-io/libritts_r+blabble-io/libritts_r+blabble-io/libritts_r",
    "train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated+stable-speech/libritts-r-tags-and-text-generated+stable-speech/libritts-r-tags-and-text-generated",
    "train_dataset_config_name": "clean+clean+other",
    "train_split_name": "train.clean.360+train.clean.100+train.other.500",
    "eval_dataset_name": "blabble-io/libritts_r+blabble-io/libritts_r",
    "eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated+stable-speech/libritts-r-tags-and-text-generated",
    "eval_dataset_config_name": "clean+other",
    "eval_split_name": "test.clean+test.other",
    "target_audio_column_name": "audio",
    "description_column_name": "text_description",
    "prompt_column_name": "text",
    "max_eval_samples": 24,
    "max_duration_in_seconds": 35,
    "min_duration_in_seconds": 2.0,
    "add_audio_samples_to_wandb": true,
    "id_column_name": "id",
    "preprocessing_num_workers": 16,
    "pad_token_id": 1024,
    "decoder_start_token_id": 1025,
    "do_train": true,
    "num_train_epochs": 1,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": true,
    "per_device_train_batch_size": 28,
    "learning_rate": 1e-4,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "weight_decay": 0.03,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 5000,
    "logging_steps": 102,
    "freeze_text_encoder": true,
    "do_eval": true,
    "predict_with_generate": true,
    "include_inputs_for_metrics": true,
    "evaluation_strategy": "steps",
    "eval_steps": 2500,
    "save_steps": 2499,
    "per_device_eval_batch_size": 1,
    "audio_encode_per_device_eval_batch_size": 24,
    "dtype": "bfloat16",
    "seed": 456,
    "dataloader_num_workers": 16
}
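The `train_dataset_name`, `train_dataset_config_name`, `train_split_name` and metadata fields above pack several datasets into single `+`-joined strings. A minimal sketch of how such a specification can be unpacked into aligned per-dataset entries, assuming (as the `load_multiple_datasets` helper in run_stable_speech_training.py appears to do) that the i-th element of each field belongs to the i-th dataset; the helper function here is illustrative only, not part of this commit:

# Sketch (assumption): unpack "+"-joined dataset specs into one dict per dataset.
def unpack_dataset_spec(names: str, configs: str, splits: str):
    names, configs, splits = names.split("+"), configs.split("+"), splits.split("+")
    assert len(names) == len(configs) == len(splits), "each field must list one entry per dataset"
    return [{"name": n, "config": c, "split": s} for n, c, s in zip(names, configs, splits)]

print(unpack_dataset_spec(
    "blabble-io/libritts_r+blabble-io/libritts_r+blabble-io/libritts_r",
    "clean+clean+other",
    "train.clean.360+train.clean.100+train.other.500",
))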
example_configs/librispeech_tts_r_dummy_dac.json (new file, 0 → 100644)

{
    "model_name_or_path": "/raid/yoach/tmp/artefacts/tiny-dac-model/",
    "save_to_disk": "/raid/yoach/tmp/artefacts/small_experiment_dataset/",
    "feature_extractor_name": "ylacombe/dac_44khZ_8kbps",
    "description_tokenizer_name": "google-t5/t5-small",
    "prompt_tokenizer_name": "google-t5/t5-small",
    "push_to_hub": false,
    "hub_model_id": "stable-speech-mini",
    "report_to": ["wandb"],
    "overwrite_output_dir": true,
    "output_dir": "/raid/yoach/tmp/artefacts/training/",
    "train_dataset_name": "blabble-io/libritts_r",
    "train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
    "train_dataset_config_name": "clean",
    "train_split_name": "train.clean.360",
    "eval_dataset_name": "blabble-io/libritts_r",
    "eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
    "eval_dataset_config_name": "clean",
    "eval_split_name": "train.clean.360",
    "target_audio_column_name": "audio",
    "description_column_name": "text_description",
    "prompt_column_name": "text",
    "max_train_samples": 4,
    "max_eval_samples": 4,
    "max_duration_in_seconds": 30,
    "min_duration_in_seconds": 1.0,
    "add_audio_samples_to_wandb": true,
    "id_column_name": "id",
    "preprocessing_num_workers": 1,
    "pad_token_id": 1024,
    "decoder_start_token_id": 1025,
    "do_train": true,
    "num_train_epochs": 180,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": false,
    "per_device_train_batch_size": 2,
    "learning_rate": 1e-3,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "weight_decay": 0.1,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.1,
    "freeze_text_encoder": true,
    "do_eval": true,
    "predict_with_generate": true,
    "include_inputs_for_metrics": true,
    "evaluation_strategy": "steps",
    "eval_steps": 30,
    "per_device_eval_batch_size": 2,
    "generation_max_length": 800,
    "do_sample": false,
    "logging_steps": 15,
    "dtype": "float32",
    "seed": 456,
    "dataloader_num_workers": 8
}
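The `pad_token_id` and `decoder_start_token_id` values in these configs line up with how the init scripts below derive special token ids from the audio codec's codebook size. A small sketch of that relationship, assuming the DAC checkpoint exposes `codebook_size = 1024` (the concrete number is an assumption; the diff only shows the `encodec_vocab_size` logic):

# Sketch (assumption): special decoder token ids sit just past the codec codebook.
codebook_size = 1024                          # assumed DAC codebook size
pad_token_id = codebook_size                  # 1024, matches "pad_token_id" in the configs
eos_token_id = codebook_size                  # eos shares the pad id in the init scripts
decoder_start_token_id = codebook_size + 1    # 1025, matches "decoder_start_token_id"
print(pad_token_id, eos_token_id, decoder_start_token_id)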
init_dummy_model_dac.py (new file, 0 → 100644)

from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from transformers import AutoConfig
from transformers import AutoModel
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

text_model = "google-t5/t5-small"
encodec_version = "ylacombe/dac_44khZ_8kbps"
num_codebooks = 9

t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size

decoder_config = StableSpeechDecoderConfig(
    vocab_size=encodec_vocab_size + 1,
    max_position_embeddings=2048,
    num_hidden_layers=4,
    ffn_dim=512,
    num_attention_heads=8,
    layerdrop=0.0,
    use_cache=True,
    activation_function="gelu",
    hidden_size=512,
    dropout=0.0,
    attention_dropout=0.0,
    activation_dropout=0.0,
    pad_token_id=encodec_vocab_size,
    eos_token_id=encodec_vocab_size,
    bos_token_id=encodec_vocab_size + 1,
    num_codebooks=num_codebooks,
)
# TODO: ?? how to make it stop ?

decoder = StableSpeechForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path=text_model,
    audio_encoder_pretrained_model_name_or_path=encodec_version,
    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
    vocab_size=t5.vocab_size,
)

# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size

# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = False  # True
model.generation_config.guidance_scale = 1  # 3.0

model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-dac-model/")
\ No newline at end of file
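Because the script registers `DACConfig`/`DACModel` with the Auto classes before assembling the composite model, the saved checkpoint can only be reloaded by a process that performs the same registration first. A hedged sketch of that reload path, assuming `StableSpeechForConditionalGeneration` supports the standard `from_pretrained` round-trip (the round-trip itself is an assumption, not shown in this commit):

from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel, StableSpeechForConditionalGeneration

# The same registration must run before loading, otherwise the "dac" model_type stored
# in the audio-encoder sub-config cannot be resolved by the Auto classes.
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

# Assumed round-trip of the checkpoint written by init_dummy_model_dac.py.
model = StableSpeechForConditionalGeneration.from_pretrained("/raid/yoach/tmp/artefacts/tiny-dac-model/")
print(model.generation_config.max_length)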
init_model.py

@@ -3,9 +3,15 @@ from transformers import T5Config, EncodecConfig
 from transformers import AutoConfig
+from transformers import AutoConfig, AutoModel
+from stable_speech import DACConfig, DACModel
+
+AutoConfig.register("dac", DACConfig)
+AutoModel.register(DACConfig, DACModel)
+
 
 text_model = "google-t5/t5-small"
-encodec_version = "facebook/encodec_24khz"
-num_codebooks = 8
+encodec_version = "ylacombe/dac_44khZ_8kbps"
+num_codebooks = 9
 
 t5 = AutoConfig.from_pretrained(text_model)

@@ -16,7 +22,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = StableSpeechDecoderConfig(
     vocab_size=encodec_vocab_size + 1,
-    max_position_embeddings=2250,  # 30 s
+    max_position_embeddings=3000,  # 30 s = 2580
     num_hidden_layers=12,
     ffn_dim=4096,
     num_attention_heads=16,
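The new `# 30 s = 2580` comment follows from the DAC codec's frame rate: at 44.1 kHz with a 512-sample hop the codec emits roughly 86 frames per second, so 30 seconds of audio is about 2580 codec frames, which is why `max_position_embeddings` is raised to 3000. A small worked check; the 44.1 kHz and 512-sample figures are assumptions about the `ylacombe/dac_44khZ_8kbps` checkpoint rather than values stated in this diff:

# Worked check (assumed codec parameters): DAC at 44.1 kHz with a 512-sample hop.
sampling_rate = 44_100                       # Hz (assumed)
hop_length = 512                             # samples per codec frame (assumed)
frame_rate = sampling_rate // hop_length     # ~86 frames per second
frames_for_30_s = 30 * frame_rate            # 30 * 86 = 2580
print(frame_rate, frames_for_30_s)           # 86 2580 -> fits under max_position_embeddings=3000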
init_model_75M.py (new file, 0 → 100644)

from stable_speech import StableSpeechConfig, StableSpeechForCausalLM, StableSpeechForConditionalGeneration, StableSpeechDecoderConfig
from transformers import T5Config, EncodecConfig
from transformers import AutoConfig
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

text_model = "google/t5-v1_1-small"
encodec_version = "ylacombe/dac_44khZ_8kbps"
num_codebooks = 9

t5 = AutoConfig.from_pretrained(text_model)
encodec = AutoConfig.from_pretrained(encodec_version)
encodec_vocab_size = encodec.codebook_size

decoder_config = StableSpeechDecoderConfig(
    vocab_size=encodec_vocab_size + 1,
    max_position_embeddings=4096,  # 30 s = 2580
    num_hidden_layers=8,
    ffn_dim=3072,
    num_attention_heads=12,
    layerdrop=0.0,
    use_cache=True,
    activation_function="gelu",
    hidden_size=768,
    dropout=0.0,
    attention_dropout=0.0,
    activation_dropout=0.0,
    pad_token_id=encodec_vocab_size,
    eos_token_id=encodec_vocab_size,
    bos_token_id=encodec_vocab_size + 1,
    num_codebooks=num_codebooks,
)

decoder = StableSpeechForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")

model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path=text_model,
    audio_encoder_pretrained_model_name_or_path=encodec_version,
    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
    vocab_size=t5.vocab_size,
)

# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size

# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = False  # True
model.generation_config.guidance_scale = 1  # 3.0

model.save_pretrained("/raid/yoach/tmp/artefacts/stable-speech-untrained-75M/")
\ No newline at end of file
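For context on the filename, the decoder dimensions chosen here (hidden size 768, FFN 3072, 8 layers) put roughly 75M parameters in the decoder's transformer blocks alone. A back-of-the-envelope sketch; it assumes a MusicGen-style layer with self- and cross-attention and deliberately ignores embeddings, LM heads, layer norms and the frozen text/audio encoders:

# Rough estimate (assumption about what the "75M" name counts): attention + FFN weights only.
hidden_size, ffn_dim, num_layers = 768, 3072, 8
per_layer = (
    4 * hidden_size * hidden_size   # self-attention q/k/v/out projections
    + 4 * hidden_size * hidden_size # cross-attention q/k/v/out projections (assumed present)
    + 2 * hidden_size * ffn_dim     # feed-forward up/down projections
)
total = per_layer * num_layers
print(f"{total / 1e6:.1f}M")        # ~75.5M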
run_stable_speech_training.py

@@ -397,7 +397,8 @@ class DataTrainingArguments:
                 "Whether to only do data preprocessing and skip training. This is especially useful when data"
                 " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
                 " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
-                " can consequently be loaded in distributed training"
+                " can consequently be loaded in distributed training."
+                " In this training script, `save_to_disk` must be set to the path in which the dataset should be saved. "
             )
         },
     )
@@ -442,6 +443,12 @@ class DataTrainingArguments:
         default="stable-speech",
         metadata={"help": "The name of the wandb project."},
     )
+    save_to_disk: str = field(
+        default=None,
+        metadata={
+            "help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
+        },
+    )
 
 
 @dataclass
 class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
@@ -781,61 +788,70 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)
+    num_workers = data_args.preprocessing_num_workers
 
     # 1. First, let's load the dataset
-    raw_datasets = DatasetDict()
-    num_workers = data_args.preprocessing_num_workers
-    [... the rest of the original top-level dataset-loading block (columns_to_keep, the
-     load_multiple_datasets calls for train and eval, and the max_train_samples /
-     max_eval_samples selection) is removed here; it is re-added, re-indented, under the
-     `else:` branch below ...]
 
+    if data_args.save_to_disk is not None:
+        os.makedirs(data_args.save_to_disk, exist_ok=True)
+
+    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
+    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
+
+    if dataset_was_precomputed:
+        vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
+    else:
+        raw_datasets = DatasetDict()
+
+        columns_to_keep = {
+            "target_audio_column_name": data_args.target_audio_column_name,
+            "prompt_column_name": data_args.prompt_column_name,
+        }
+        if data_args.description_column_name is not None:
+            columns_to_keep["description_column_nam"] = data_args.description_column_name
+
+        if training_args.do_train:
+            raw_datasets["train"] = load_multiple_datasets(
+                accelerator,
+                data_args.train_dataset_name,
+                data_args.train_dataset_config_name,
+                metadata_dataset_names=data_args.train_metadata_dataset_name,
+                splits=data_args.train_split_name,
+                dataset_samples=data_args.train_dataset_samples,
+                seed=training_args.seed,
+                cache_dir=model_args.cache_dir,
+                num_proc=data_args.preprocessing_num_workers,
+                id_column_name=data_args.id_column_name,
+                columns_to_keep=columns_to_keep.values(),
+                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
+            )
+
+            for key in columns_to_keep:
+                if columns_to_keep[key] not in raw_datasets["train"].column_names:
+                    raise ValueError(
+                        f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
+                        f" Make sure to set `--{key}` to the correct audio column - one of"
+                        f" {', '.join(raw_datasets['train'].column_names)}."
+                    )
+
+            if data_args.max_train_samples is not None:
+                raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
+
+        if training_args.do_eval:
+            raw_datasets["eval"] = load_multiple_datasets(
+                accelerator,
+                data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
+                data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
+                metadata_dataset_names=data_args.eval_metadata_dataset_name,
+                splits=data_args.eval_split_name,
+                cache_dir=model_args.cache_dir,
+                num_proc=data_args.preprocessing_num_workers,
+                id_column_name=data_args.id_column_name,
+                columns_to_keep=columns_to_keep.values(),
+                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
+            )
+
+            if data_args.max_eval_samples is not None:
+                raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
 
     # 2. Next, let's load the config as we might need it to create
@@ -921,160 +937,164 @@ def main():
     num_codebooks = model.decoder.config.num_codebooks
     bandwidth = model_args.bandwidth
 
-    [... previously, the resampling / tokenization / filtering / encodec-encoding pipeline below
-     ran unconditionally at the top level of main() (with `postprocess_dataset` mapped with
-     `num_proc=num_workers`); that copy is removed here and re-added under the new guard ...]
+    if not dataset_was_precomputed:
+        # resample target audio
+        raw_datasets = raw_datasets.cast_column(
+            target_audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+        )
+
+        # Preprocessing the datasets.
+        # We need to read the audio files as arrays and tokenize the texts.
+        def pass_through_processors(batch):
+            # load audio
+            if description_column_name is not None:
+                text = batch[description_column_name]
+                batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
+
+            if prompt_column_name is not None:
+                text = batch[prompt_column_name]
+                batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
+
+            # load audio
+            target_sample = batch[target_audio_column_name]
+            labels = feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
+            batch["labels"] = labels["input_values"]
+
+            # take length of raw audio waveform
+            batch["target_length"] = len(target_sample["array"].squeeze())
+            return batch
+
+        with accelerator.main_process_first():
+            vectorized_datasets = raw_datasets.map(
+                pass_through_processors,
+                remove_columns=next(iter(raw_datasets.values())).column_names,
+                num_proc=num_workers,
+                desc="preprocess datasets",
+            )
+
+        def is_audio_in_length_range(length):
+            return length > min_target_length and length < max_target_length
+
+        # filter data that is shorter than min_target_length
+        vectorized_datasets = vectorized_datasets.filter(
+            is_audio_in_length_range,
+            num_proc=num_workers,
+            input_columns=["target_length"],
+        )
+
+        # 5. Now we encode the audio labels with encodec.
+        # We use Accelerate to perform distributed inference
+        logger.info("*** Encode target audio with encodec ***")
+
+        # no need to prepare audio_decoder because used for inference without mixed precision
+        # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
+        audio_decoder = model.audio_encoder
+
+        encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
+
+        def apply_audio_decoder(batch):
+            len_audio = batch.pop("len_audio")
+            audio_decoder.to(batch["input_values"].device).eval()
+            with torch.no_grad():
+                labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
+            output = {}
+            output["len_audio"] = len_audio
+            # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
+            output["labels"] = labels.squeeze(0).transpose(1, 2)
+            output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
+            return output
+
+        for split in vectorized_datasets:
+            data_loader = DataLoader(
+                vectorized_datasets[split],
+                batch_size=training_args.audio_encode_per_device_eval_batch_size,
+                collate_fn=encoder_data_collator,
+                num_workers=training_args.dataloader_num_workers,
+                pin_memory=True,
+            )
+            data_loader = accelerator.prepare(data_loader)
+
+            all_generated_labels = []
+            all_ratios = []
+            all_lens = []
+            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
+                generate_labels = apply_audio_decoder(batch)
+                generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
+                generate_labels = accelerator.gather_for_metrics(generate_labels)
+
+                all_generated_labels.extend(generate_labels["labels"].cpu())
+                all_ratios.extend(generate_labels["ratio"].cpu())
+                all_lens.extend(generate_labels["len_audio"].cpu())
+
+            # (1, codebooks, seq_len) where seq_len=1
+            eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
+            bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
+
+            def postprocess_dataset(input_ids, prompt_input_ids, idx):
+                # (1, codebooks, seq_len)
+                labels = all_generated_labels[idx].transpose(0, 1).unsqueeze(0)
+                len_ = int(all_ratios[idx] * all_lens[idx])
+                labels = labels[:, :, :len_]
+                # labels = labels[:, :, :(len_)%10+500] # TODO: change
+
+                # add bos
+                labels = torch.cat([bos_labels, labels], dim=-1)
+
+                labels, delay_pattern_mask = build_delay_pattern_mask(
+                    labels,
+                    bos_token_id=audio_encoder_bos_token_id,
+                    pad_token_id=audio_encoder_eos_token_id,
+                    max_length=labels.shape[-1] + num_codebooks,
+                    num_codebooks=num_codebooks,
+                )
+
+                # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
+                # to take care of EOS
+                # we want labels to look like this:
+                #  - [B, a, b, E, E, E, E]
+                #  - [B, B, c, d, E, E, E]
+                #  - [B, B, B, e, f, E, E]
+                #  - [B, B, B, B, g, h, E]
+                labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
+
+                # the first timestamp is associated to a row full of BOS, let's get rid of it
+                # we also remove the last timestamps (full of PAD)
+                output = {"labels": labels[:, 1:].cpu()}
+                output["input_ids"] = input_ids
+                output["prompt_input_ids"] = prompt_input_ids
+                return output
+
+            # TODO: done multiple times, how to deal with it.
+            with accelerator.main_process_first():
+                vectorized_datasets[split] = vectorized_datasets[split].map(
+                    postprocess_dataset,
+                    num_proc=1,  # this one is resource consuming if many processors.
+                    input_columns=["input_ids", "prompt_input_ids"],
+                    desc="Postprocessing labeling",
+                    with_indices=True,
+                    writer_batch_size=200,
+                )
+
+            accelerator.free_memory()
+            del generate_labels
+
+    if data_args.save_to_disk is not None and not dataset_was_precomputed:
+        vectorized_datasets.save_to_disk(data_args.save_to_disk)
+        logger.info(f"Dataset saved at {data_args.save_to_disk}")
 
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will most likely
     # be a timeout when running the script in distributed mode.
     # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
     # cached dataset
-    if data_args.preprocessing_only:
-        # TODO: save to disk in this step instead of something else ??
-        logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
+    if data_args.preprocessing_only and data_args.save_to_disk is None:
+        raise ValueError(
+            "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
+        )
+    elif data_args.preprocessing_only:
+        logger.info(f"Data preprocessing finished. Files saved at {data_args.save_to_disk}")
+        return
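The comments inside `postprocess_dataset` above describe the target layout produced by `build_delay_pattern_mask`: each codebook is delayed one extra step, padded with BOS tokens at the start and EOS tokens at the end. A toy illustration of just that layout in plain Python (not the repo's implementation), assuming 4 codebooks and 2 real frames per codebook as in the comment; "B" stands for the audio-encoder BOS id and "E" for the EOS/pad id:

# Toy illustration of the delayed codebook layout described in the comment above.
num_codebooks = 4
codes = [["a", "b"], ["c", "d"], ["e", "f"], ["g", "h"]]   # 2 frames per codebook
total_len = 1 + len(codes[0]) + num_codebooks              # bos + frames + one delay step per codebook

for k in range(num_codebooks):
    row = ["B"] * (k + 1) + codes[k]
    row += ["E"] * (total_len - len(row))
    print(row)
# ['B', 'a', 'b', 'E', 'E', 'E', 'E']
# ['B', 'B', 'c', 'd', 'E', 'E', 'E']
# ['B', 'B', 'B', 'e', 'f', 'E', 'E']
# ['B', 'B', 'B', 'B', 'g', 'h', 'E']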