chenpangpang / parler-tts · Commit 31a54850

make style

Authored Apr 08, 2024 by Yoach Lacombe
Parent: 91542bfa

Showing 12 changed files with 497 additions and 404 deletions (+497 −404)
init_dummy_model.py                             +8    −10
init_dummy_model_dac.py                         +8    −10
init_model.py                                   +9    −11
init_model_75M.py                               +9    −11
parler_tts/__init__.py                          +7    −2
parler_tts/configuration_parler_tts.py          +1    −1
parler_tts/dac_wrapper/__init__.py              +1    −1
parler_tts/dac_wrapper/configuration_dac.py     +3    −4
parler_tts/dac_wrapper/modeling_dac.py          +24   −23
parler_tts/modeling_parler_tts.py               +87   −68
push_dac_to_hub.py                              +3    −2
run_parler_tts_training.py                      +337  −261
init_dummy_model.py

@@ -13,7 +13,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -27,28 +27,26 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
     bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )

 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size=t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
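For orientation (an annotation, not part of the commit): these init scripts stack the special ids directly on top of the audio codebook. A minimal sketch of the layout, assuming a codebook of size 1024:

# Illustration only: how the special token ids relate to the codebook size.
encodec_vocab_size = 1024              # assumed codebook size for the example
vocab_size = encodec_vocab_size + 1    # 1025: audio codes 0..1023 plus one special id
pad_token_id = encodec_vocab_size      # 1024, shared with eos
eos_token_id = encodec_vocab_size      # 1024
bos_token_id = encodec_vocab_size + 1  # 1025, also used as decoder_start_token_id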
init_dummy_model_dac.py

@@ -20,7 +20,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -34,28 +34,26 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
     bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )

 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size=t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
init_model.py

@@ -21,7 +21,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=3000,  # 30 s = 2580
     num_hidden_layers=12,
     ffn_dim=4096,
@@ -35,7 +35,7 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
     bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
@@ -44,17 +44,15 @@ decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size=t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
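A note on the `# 30 s = 2580` comment (my arithmetic, not from the commit): with the 86 Hz frame rate that appears as the `DACConfig` default further down in this diff, 30 seconds of audio is 30 × 86 = 2580 codec frames, so `max_position_embeddings=3000` leaves headroom above a 30-second clip:

frame_rate = 86                  # Hz, the DACConfig default shown below
seconds = 30
frames = frame_rate * seconds    # 2580 decoder positions for 30 s of audio
assert frames == 2580 and frames < 3000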
init_model_75M.py

@@ -21,7 +21,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
     vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=4096,  # 30 s = 2580
     num_hidden_layers=8,
     ffn_dim=3072,
@@ -35,7 +35,7 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
     bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
@@ -44,17 +44,15 @@ decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
-    vocab_size=t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
parler_tts/__init__.py

 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
-from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
+from .modeling_parler_tts import (
+    ParlerTTSForCausalLM,
+    ParlerTTSForConditionalGeneration,
+    apply_delay_pattern_mask,
+    build_delay_pattern_mask,
+)
 from .dac_wrapper import DACConfig, DACModel
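After this change the package's public surface is unchanged; a usage sketch (assumes `parler_tts` is importable):

from parler_tts import (
    DACConfig,
    DACModel,
    ParlerTTSConfig,
    ParlerTTSDecoderConfig,
    ParlerTTSForCausalLM,
    ParlerTTSForConditionalGeneration,
    apply_delay_pattern_mask,
    build_delay_pattern_mask,
)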
parler_tts/configuration_parler_tts.py
parler_tts/dac_wrapper/__init__.py
parler_tts/dac_wrapper/configuration_dac.py

@@ -14,7 +14,6 @@ class DACConfig(PretrainedConfig):
         frame_rate: int = 86,
         **kwargs,
     ):
         self.codebook_size = codebook_size
         self.model_bitrate = model_bitrate
         self.latent_dim = latent_dim
parler_tts/dac_wrapper/modeling_dac.py

@@ -7,9 +7,9 @@ from .configuration_dac import DACConfig
 from dac.model import DAC

 # model doesn't support batching yet
 class DACModel(PreTrainedModel):
     config_class = DACConfig
@@ -17,12 +17,14 @@ class DACModel(PreTrainedModel):
         super().__init__(config)

         self.model = DAC(
             n_codebooks=config.num_codebooks,
             latent_dim=config.latent_dim,
             codebook_size=config.codebook_size,
         )

-    def encode(self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None):
+    def encode(
+        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
+    ):
         """
         Encodes the input audio waveform into discrete codes.
@@ -93,13 +95,12 @@ class DACModel(PreTrainedModel):
         return EncodecEncoderOutput(encoded_frames, scales)

     def decode(
         self,
         audio_codes,
         audio_scales,
         padding_mask=None,
         return_dict=None,
     ):
         """
         Decodes the given frames into an output audio waveform.
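A quick usage sketch for the wrapper (my illustration; the `encode` signature is the one shown above, shapes are assumed, and per the comment the model does not batch yet, so one waveform at a time):

import torch

from parler_tts import DACConfig, DACModel

dac = DACModel(DACConfig())        # randomly initialised; load real weights in practice
wav = torch.randn(1, 1, 44100)     # (batch=1, channels, samples): one second at 44.1 kHz

enc = dac.encode(wav, sample_rate=44100)
codes, scales = enc.audio_codes, enc.audio_scales   # EncodecEncoderOutput fields
audio = dac.decode(codes, scales).audio_values      # same .audio_values access as in modeling_parler_tts.py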
parler_tts/modeling_parler_tts.py

@@ -46,7 +46,6 @@ from transformers.utils import (
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig

 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -60,6 +59,7 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 ]

 def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
     the mask is set to -1, and otherwise setting to the value detailed in the mask."""
@@ -68,7 +68,10 @@ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
     return input_ids
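To make the masking rule concrete: positions where the mask holds -1 keep the model's prediction; every other position is overwritten with the mask's own value. A tiny self-contained example (ids invented, 1024 standing in for the pad id):

import torch

mask = torch.tensor([[-1, -1, 1024, 1024]])
preds = torch.tensor([[7, 42, 99, 99]])
out = torch.where(mask == -1, preds, mask)   # the same rule as apply_delay_pattern_mask
print(out)                                   # tensor([[   7,   42, 1024, 1024]])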
 def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
     """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
     one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
     are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -91,9 +94,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
     bsz, num_codebooks, seq_len = input_ids.shape

-    input_ids_shifted = (
-        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-    )
+    input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1

     # we only apply the mask if we have a large enough seq len - otherwise we return as is
     if max_length < 2 * num_codebooks - 1:
@@ -132,6 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
     return input_ids, pattern_mask
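The 4-codebook example in the docstring is truncated by the diff view; the layout it describes, following the upstream MusicGen implementation this function mirrors (my reconstruction, with B the bos id, P the pad id, and -1 the positions left free for the model to predict):

# delay pattern mask of shape (codebooks, seq_len), 4 codebooks, max length 8:
# [B, -1, -1, -1, -1,  P,  P,  P]
# [B,  B, -1, -1, -1, -1,  P,  P]
# [B,  B,  B, -1, -1, -1, -1,  P]
# [B,  B,  B,  B, -1, -1, -1, -1]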
 @dataclass
 class ParlerTTSUnconditionalInput(ModelOutput):
     """
@@ -812,10 +814,24 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
                     "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
                 )
                 if past_key_values is None:
-                    attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                    attention_mask = torch.cat(
+                        [
+                            prompt_attention_mask,
+                            torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype),
+                        ],
+                        dim=1,
+                    )
                 else:
                     generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
-                    attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                    attention_mask = torch.cat(
+                        [
+                            prompt_attention_mask,
+                            torch.ones(
+                                (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
+                            ),
+                        ],
+                        dim=1,
+                    )

         input_shape = inputs_embeds.size()[:-1]
         attention_mask = _prepare_4d_causal_attention_mask(
@@ -1098,7 +1114,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         if labels is not None:
             loss = torch.zeros([], device=self.device)
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
-            logits = lm_logits[:,:, -labels.shape[1]:]
+            logits = lm_logits[:, :, -labels.shape[1] :]

             loss_fct = CrossEntropyLoss()
             loss = torch.zeros([], device=self.device)
@@ -1107,7 +1123,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)

             # we use every codebooks token AND one single EOS at the end of each codebooks
             mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))
             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
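The loop above accumulates one cross-entropy term per codebook. A minimal standalone sketch of that idea (shapes assumed, and simplified: the real code also applies the eos/bos mask built above):

import torch
from torch.nn import CrossEntropyLoss

bsz, seq_len, num_codebooks, vocab = 2, 6, 4, 1026
logits = torch.randn(bsz, num_codebooks, seq_len, vocab)
labels = torch.randint(0, vocab, (bsz, seq_len, num_codebooks))
labels[:, -1, :] = -100                  # ignored positions, like the bos fill above

loss_fct = CrossEntropyLoss()            # ignores targets equal to -100 by default
loss = torch.zeros([])
for codebook in range(num_codebooks):
    loss += loss_fct(logits[:, codebook].reshape(-1, vocab), labels[..., codebook].reshape(-1))
loss = loss / num_codebooks              # average over codebooks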
@@ -1200,7 +1216,9 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         }

     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(
+        self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
+    ):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -1486,9 +1504,10 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id) & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(batch_size, self.decoder.num_codebooks, -1)

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_ids
@@ -1520,9 +1539,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 "Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
             )
         if config is None:
-            config = ParlerTTSConfig.from_sub_models_config(
-                text_encoder.config, audio_encoder.config, decoder.config
-            )
+            config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
         else:
             if not isinstance(config, self.config_class):
                 raise ValueError(f"Config: {config} has to be of type {self.config_class}")
@@ -1588,7 +1605,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # prompt embeddings
         self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)

         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
                 f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
@@ -1974,7 +1990,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             # TODO: verify it does what's expected
             decoder_input_ids = shift_tokens_right(
                 labels, self.config.pad_token_id, self.config.decoder_start_token_id
             ).transpose(1, 2)

         elif decoder_input_ids is None and decoder_inputs_embeds is None:
             audio_encoder_outputs = self.audio_encoder(
@@ -2064,9 +2080,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if decoder_attention_mask is not None:
             decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
         if prompt_hidden_states is not None:
             prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
         if prompt_attention_mask is not None:
             prompt_attention_mask = prompt_attention_mask.repeat((2, 1))

         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
@@ -2083,7 +2099,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             # we only want to use prompt signal in the 1st generation step but keeping the attention mask
             prompt_hidden_states = None

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
@@ -2244,7 +2259,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         return model_kwargs

     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)

     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
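`shift_tokens_right` here is the usual transformers seq2seq helper: prepend `decoder_start_token_id`, drop the last position, and replace any `-100` with the pad id (the `.transpose(1, 2)` afterwards swaps the codebook and time axes of the 3-D labels). A flat 2-D toy illustration of the shift itself (ids invented):

import torch

labels = torch.tensor([[5, 6, 7, -100]])
pad_token_id, decoder_start_token_id = 1024, 1025

shifted = labels.new_zeros(labels.shape)
shifted[:, 1:] = labels[:, :-1]            # shift everything one step right
shifted[:, 0] = decoder_start_token_id     # generation starts from this id
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted)                             # tensor([[1025,    5,    6,    7]])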
@@ -2586,7 +2601,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
         _, mask = self.decoder.build_delay_pattern_mask(
             input_ids,
@@ -2595,10 +2609,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             max_length=output_ids.shape[1],
         )

         mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
-        output_ids = output_ids[mask].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)

         # append the frame dimension back to the audio codes
         output_ids = output_ids[None, ...]
@@ -2607,7 +2619,11 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        decode_sequentially = (
+            generation_config.bos_token_id in output_ids
+            or generation_config.pad_token_id in output_ids
+            or generation_config.eos_token_id in output_ids
+        )
         if not decode_sequentially:
             output_values = self.audio_encoder.decode(
                 output_ids,
@@ -2617,16 +2633,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0)
+                sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
                 if sample_mask.sum() > 0:
                     sample = sample[:, :, sample_mask]
                     sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
                     output_values.append(sample.transpose(0, 2))
                 else:
                     output_values.append(torch.zeros((1, 1, 1)).to(self.device))
             # TODO: we should keep track of output length as well. Not really straightfoward tbh
-            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)
+            output_values = (
+                torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
+                .squeeze(-1)
+                .squeeze(-1)
+            )

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
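For reference, the `pad_sequence` call above batches the per-sample `(time, 1, 1)` clips that the loop produced by zero-padding to the longest one. A small standalone demo (lengths invented):

import torch

clips = [torch.randn(16000, 1, 1), torch.randn(24000, 1, 1)]   # per-sample (time, 1, 1)
batch = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True, padding_value=0)
print(batch.shape)                            # torch.Size([2, 24000, 1, 1])
print(batch.squeeze(-1).squeeze(-1).shape)    # torch.Size([2, 24000])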
push_dac_to_hub.py

 import dac

 # Download a model
 model_path = dac.utils.download(model_type="44khz")
 model = dac.DAC.load(model_path)
@@ -10,6 +10,7 @@ hf_dac = DACModel(DACConfig())
 hf_dac.model.load_state_dict(model.state_dict())

 from transformers import AutoConfig, AutoModel

 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
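Once `DACConfig`/`DACModel` are registered as above, the wrapper resolves through the Auto classes like any built-in architecture. A local sketch (my illustration; the Hub push itself is not shown in this hunk):

from transformers import AutoConfig, AutoModel

from parler_tts import DACConfig, DACModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

config = AutoConfig.for_model("dac")    # resolves to DACConfig
model = AutoModel.from_config(config)   # resolves to DACModel (random weights)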
run_parler_tts_training.py

This diff is collapsed in the original view (+337 −261).