Commit 31a54850 authored by Yoach Lacombe

make style

parent 91542bfa
@@ -13,7 +13,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -27,28 +27,26 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
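The init scripts in this commit repeat the same token-id arithmetic at different model scales, so it is worth spelling out once. A minimal sketch of that layout, with 2048 as an assumed stand-in for the `encodec.codebook_size` the scripts actually read:

    # Hedged sketch; 2048 is an assumption, the scripts read encodec.codebook_size at runtime.
    codebook_size = 2048
    pad_token_id = eos_token_id = codebook_size                # 2048, the first id past the real codes
    bos_token_id = decoder_start_token_id = codebook_size + 1  # 2049, fed as the first decoder input
    vocab_size = codebook_size + 1                             # real codes plus the shared pad/eos sentinel

Note that a single sentinel id is reused for both pad and eos, which is what the `# TODO: ?? how to make it stop ?` comment is wrestling with: generation can only stop once the decoder emits that shared id.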
@@ -20,7 +20,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -34,28 +34,26 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
@@ -21,7 +21,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=3000,  # 30 s = 2580
     num_hidden_layers=12,
     ffn_dim=4096,
@@ -35,7 +35,7 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
@@ -44,17 +44,15 @@ decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
@@ -21,7 +21,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=4096,  # 30 s = 2580
     num_hidden_layers=8,
     ffn_dim=3072,
@@ -35,7 +35,7 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
@@ -44,17 +44,15 @@ decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
-from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
+from .modeling_parler_tts import (
+    ParlerTTSForCausalLM,
+    ParlerTTSForConditionalGeneration,
+    apply_delay_pattern_mask,
+    build_delay_pattern_mask,
+)
 from .dac_wrapper import DACConfig, DACModel
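Because this `__init__` re-exports the public classes, downstream code can import everything from the package root. A hedged usage line (the top-level package name `parler_tts` is assumed from the module layout):

    from parler_tts import ParlerTTSForConditionalGeneration, build_delay_pattern_mask, DACModel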
@@ -14,7 +14,6 @@ class DACConfig(PretrainedConfig):
         frame_rate: int = 86,
         **kwargs,
     ):
         self.codebook_size = codebook_size
         self.model_bitrate = model_bitrate
         self.latent_dim = latent_dim
...
@@ -7,9 +7,9 @@ from .configuration_dac import DACConfig
 from dac.model import DAC
 # model doesn't support batching yet
 class DACModel(PreTrainedModel):
     config_class = DACConfig
@@ -17,12 +17,14 @@ class DACModel(PreTrainedModel):
         super().__init__(config)
         self.model = DAC(
-            n_codebooks = config.num_codebooks,
-            latent_dim = config.latent_dim,
-            codebook_size = config.codebook_size,
+            n_codebooks=config.num_codebooks,
+            latent_dim=config.latent_dim,
+            codebook_size=config.codebook_size,
         )
-    def encode(self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None):
+    def encode(
+        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
+    ):
         """
         Encodes the input audio waveform into discrete codes.
@@ -93,13 +95,12 @@ class DACModel(PreTrainedModel):
         return EncodecEncoderOutput(encoded_frames, scales)
     def decode(
         self,
         audio_codes,
         audio_scales,
-        padding_mask = None,
-        return_dict = None,
+        padding_mask=None,
+        return_dict=None,
     ):
         """
         Decodes the given frames into an output audio waveform.
...
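For orientation, a minimal round-trip through the wrapper above, using only the signatures visible in this diff: `encode(input_values, ...)` returning an `EncodecEncoderOutput`, and `decode(audio_codes, audio_scales, ...)` returning an object with `.audio_values`. The package name, sample rate, and tensor shape are assumptions:

    import torch
    from parler_tts import DACConfig, DACModel  # package name assumed

    codec = DACModel(DACConfig())        # randomly initialised here, for shape-checking only
    waveform = torch.randn(1, 1, 44100)  # (batch, channels, samples); 1 s at an assumed 44.1 kHz
    encoded = codec.encode(waveform)     # EncodecEncoderOutput with audio_codes and audio_scales
    decoded = codec.decode(encoded.audio_codes, encoded.audio_scales)
    print(decoded.audio_values.shape)

Per the `# model doesn't support batching yet` comment, keep the batch dimension at 1.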
@@ -46,7 +46,6 @@ from transformers.utils import (
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -60,6 +59,7 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 ]
 def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
     the mask is set to -1, and otherwise setting to the value detailed in the mask."""
@@ -68,7 +68,10 @@ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
     return input_ids
-def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+def build_delay_pattern_mask(
+    input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int
+):
     """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
     one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
     are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -91,9 +94,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
     bsz, num_codebooks, seq_len = input_ids.shape
-    input_ids_shifted = (
-        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-    )
+    input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
     # we only apply the mask if we have a large enough seq len - otherwise we return as is
     if max_length < 2 * num_codebooks - 1:
@@ -132,6 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
     return input_ids, pattern_mask
 @dataclass
 class ParlerTTSUnconditionalInput(ModelOutput):
     """
@@ -812,10 +814,24 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
                 "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
             )
             if past_key_values is None:
-                attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                attention_mask = torch.cat(
+                    [
+                        prompt_attention_mask,
+                        torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype),
+                    ],
+                    dim=1,
+                )
             else:
                 generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
-                attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0] ,generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                attention_mask = torch.cat(
+                    [
+                        prompt_attention_mask,
+                        torch.ones(
+                            (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
+                        ),
+                    ],
+                    dim=1,
+                )
         input_shape = inputs_embeds.size()[:-1]
         attention_mask = _prepare_4d_causal_attention_mask(
@@ -1098,7 +1114,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         if labels is not None:
             loss = torch.zeros([], device=self.device)
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
-            logits = lm_logits[:,:,-labels.shape[1]:]
+            logits = lm_logits[:, :, -labels.shape[1] :]
             loss_fct = CrossEntropyLoss()
             loss = torch.zeros([], device=self.device)
@@ -1107,7 +1123,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
             # we use every codebooks token AND one single EOS at the end of each codebooks
-            mask = (input_ids.transpose(1,2) != self.config.eos_token_id) & ((labels != -100))
+            mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))
             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
@@ -1200,7 +1216,9 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         }
     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(
+        self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
+    ):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -1486,9 +1504,10 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(batch_size, self.decoder.num_codebooks, -1)
         if generation_config.return_dict_in_generate:
             outputs.sequences = output_ids
@@ -1520,9 +1539,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 "Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
             )
         if config is None:
-            config = ParlerTTSConfig.from_sub_models_config(
-                text_encoder.config, audio_encoder.config, decoder.config
-            )
+            config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
         else:
             if not isinstance(config, self.config_class):
                 raise ValueError(f"Config: {config} has to be of type {self.config_class}")
@@ -1588,7 +1605,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # prompt embeddings
         self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
                 f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
@@ -1974,7 +1990,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 # TODO: verify it does what's expected
                 decoder_input_ids = shift_tokens_right(
                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
-                ).transpose(1,2)
+                ).transpose(1, 2)
             elif decoder_input_ids is None and decoder_inputs_embeds is None:
                 audio_encoder_outputs = self.audio_encoder(
@@ -2064,9 +2080,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             if decoder_attention_mask is not None:
                 decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
             if prompt_hidden_states is not None:
-                prompt_hidden_states = prompt_hidden_states.repeat((2,1,1))
+                prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
             if prompt_attention_mask is not None:
-                prompt_attention_mask = prompt_attention_mask.repeat((2,1))
+                prompt_attention_mask = prompt_attention_mask.repeat((2, 1))
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
@@ -2083,7 +2099,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             # we only want to use prompt signal in the 1st generation step but keeping the attention mask
             prompt_hidden_states = None
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
@@ -2244,7 +2259,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         return model_kwargs
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
@@ -2586,7 +2601,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
         _, mask = self.decoder.build_delay_pattern_mask(
             input_ids,
@@ -2595,10 +2609,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             max_length=output_ids.shape[1],
         )
-        mask = (mask != generation_config.bos_token_id)&(mask != generation_config.pad_token_id)
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
-        output_ids = output_ids[mask].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
         # append the frame dimension back to the audio codes
         output_ids = output_ids[None, ...]
@@ -2607,7 +2619,11 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
            audio_scales = [None] * batch_size
-        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        decode_sequentially = (
+            generation_config.bos_token_id in output_ids
+            or generation_config.pad_token_id in output_ids
+            or generation_config.eos_token_id in output_ids
+        )
         if not decode_sequentially:
             output_values = self.audio_encoder.decode(
                 output_ids,
@@ -2617,16 +2633,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0,1)) == 0)
+                sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
-                if sample_mask.sum()>0:
+                if sample_mask.sum() > 0:
                     sample = sample[:, :, sample_mask]
                     sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
-                    output_values.append(sample.transpose(0,2))
+                    output_values.append(sample.transpose(0, 2))
                 else:
-                    output_values.append(torch.zeros((1,1,1)).to(self.device))
+                    output_values.append(torch.zeros((1, 1, 1)).to(self.device))
             # TODO: we should keep track of output length as well. Not really straightfoward tbh
-        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)
+        output_values = (
+            torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
+            .squeeze(-1)
+            .squeeze(-1)
+        )
         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
...
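The `build_delay_pattern_mask` docstrings above describe the codebook-delay layout in words; a self-contained toy reconstruction of just the mask makes the staircase visible. This mirrors the documented pattern rather than the library function itself, and the 2049/2048 ids are the bos/pad placeholders from the init scripts:

    import torch

    def toy_delay_pattern(num_codebooks: int, max_length: int, bos_id: int, pad_id: int) -> torch.Tensor:
        # -1 marks positions left for the model to predict; bos and pad fill the triangles.
        mask = torch.full((num_codebooks, max_length), -1, dtype=torch.long)
        for k in range(num_codebooks):
            mask[k, : k + 1] = bos_id  # codebook k starts k + 1 steps late
            if num_codebooks - 1 - k > 0:
                mask[k, -(num_codebooks - 1 - k) :] = pad_id  # and is cut short by the same offset
        return mask

    print(toy_delay_pattern(num_codebooks=4, max_length=8, bos_id=2049, pad_id=2048))

For 4 codebooks and length 8 this prints the pattern from the docstring: the first row is `[B, -1, -1, -1, -1, P, P, P]` and the last is `[B, B, B, B, -1, -1, -1, -1]`.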
 import dac
 # Download a model
 model_path = dac.utils.download(model_type="44khz")
 model = dac.DAC.load(model_path)
@@ -10,6 +10,7 @@ hf_dac = DACModel(DACConfig())
 hf_dac.model.load_state_dict(model.state_dict())
 from transformers import AutoConfig, AutoModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
...
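Registering the wrapper with the Auto classes, as the last two lines above do, lets generic `transformers` loading code resolve the custom `"dac"` model type. A hedged sketch of what that enables (the local save path is made up):

    from transformers import AutoModel

    hf_dac.save_pretrained("./dac_44khz")                # hf_dac built in the script above
    reloaded = AutoModel.from_pretrained("./dac_44khz")  # dispatches to DACModel via the registration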