Commit 31a54850 authored by Yoach Lacombe

make style

parent 91542bfa
@@ -13,7 +13,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -27,28 +27,26 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
@@ -20,7 +20,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -34,28 +34,26 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
@@ -21,7 +21,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=3000, # 30 s = 2580
     num_hidden_layers=12,
     ffn_dim=4096,
@@ -35,7 +35,7 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
@@ -44,17 +44,15 @@ decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
@@ -21,7 +21,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=4096, # 30 s = 2580
     num_hidden_layers=8,
     ffn_dim=3072,
@@ -35,7 +35,7 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
@@ -44,17 +44,15 @@ decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")
 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size
...
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
-from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
+from .modeling_parler_tts import (
+    ParlerTTSForCausalLM,
+    ParlerTTSForConditionalGeneration,
+    apply_delay_pattern_mask,
+    build_delay_pattern_mask,
+)
 from .dac_wrapper import DACConfig, DACModel
@@ -14,7 +14,6 @@ class DACConfig(PretrainedConfig):
         frame_rate: int = 86,
         **kwargs,
     ):
         self.codebook_size = codebook_size
         self.model_bitrate = model_bitrate
         self.latent_dim = latent_dim
...
@@ -7,9 +7,9 @@ from .configuration_dac import DACConfig
 from dac.model import DAC
 # model doesn't support batching yet
 class DACModel(PreTrainedModel):
     config_class = DACConfig
@@ -17,12 +17,14 @@ class DACModel(PreTrainedModel):
         super().__init__(config)
         self.model = DAC(
-            n_codebooks = config.num_codebooks,
-            latent_dim = config.latent_dim,
-            codebook_size = config.codebook_size,
+            n_codebooks=config.num_codebooks,
+            latent_dim=config.latent_dim,
+            codebook_size=config.codebook_size,
         )
-    def encode(self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None):
+    def encode(
+        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
+    ):
         """
         Encodes the input audio waveform into discrete codes.
@@ -93,13 +95,12 @@ class DACModel(PreTrainedModel):
         return EncodecEncoderOutput(encoded_frames, scales)
     def decode(
         self,
         audio_codes,
         audio_scales,
-        padding_mask = None,
-        return_dict = None,
+        padding_mask=None,
+        return_dict=None,
     ):
         """
         Decodes the given frames into an output audio waveform.
...
@@ -46,7 +46,6 @@ from transformers.utils import (
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -60,6 +59,7 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 ]
 def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
     the mask is set to -1, and otherwise setting to the value detailed in the mask."""
@@ -68,7 +68,10 @@ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
     return input_ids
-def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+def build_delay_pattern_mask(
+    input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int
+):
     """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
     one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
     are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -91,9 +94,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
     bsz, num_codebooks, seq_len = input_ids.shape
-    input_ids_shifted = (
-        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-    )
+    input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
     # we only apply the mask if we have a large enough seq len - otherwise we return as is
     if max_length < 2 * num_codebooks - 1:
@@ -132,6 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
     return input_ids, pattern_mask
 @dataclass
 class ParlerTTSUnconditionalInput(ModelOutput):
     """
@@ -812,10 +814,24 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
                 "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
             )
             if past_key_values is None:
-                attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                attention_mask = torch.cat(
+                    [
+                        prompt_attention_mask,
+                        torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype),
+                    ],
+                    dim=1,
+                )
             else:
                 generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
-                attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0] ,generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                attention_mask = torch.cat(
+                    [
+                        prompt_attention_mask,
+                        torch.ones(
+                            (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
+                        ),
+                    ],
+                    dim=1,
+                )
         input_shape = inputs_embeds.size()[:-1]
         attention_mask = _prepare_4d_causal_attention_mask(
@@ -1098,7 +1114,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         if labels is not None:
             loss = torch.zeros([], device=self.device)
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
-            logits = lm_logits[:,:,-labels.shape[1]:]
+            logits = lm_logits[:, :, -labels.shape[1] :]
             loss_fct = CrossEntropyLoss()
             loss = torch.zeros([], device=self.device)
@@ -1107,7 +1123,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
             # we use every codebooks token AND one single EOS at the end of each codebooks
-            mask = (input_ids.transpose(1,2) != self.config.eos_token_id) & ((labels != -100))
+            mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))
             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
@@ -1200,7 +1216,9 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         }
     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(
+        self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
+    ):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -1486,9 +1504,10 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(batch_size, self.decoder.num_codebooks, -1)
         if generation_config.return_dict_in_generate:
             outputs.sequences = output_ids
@@ -1520,9 +1539,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 "Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
             )
         if config is None:
-            config = ParlerTTSConfig.from_sub_models_config(
-                text_encoder.config, audio_encoder.config, decoder.config
-            )
+            config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
         else:
             if not isinstance(config, self.config_class):
                 raise ValueError(f"Config: {config} has to be of type {self.config_class}")
@@ -1588,7 +1605,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # prompt embeddings
         self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
                 f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
@@ -1974,7 +1990,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             # TODO: verify it does what's expected
             decoder_input_ids = shift_tokens_right(
                 labels, self.config.pad_token_id, self.config.decoder_start_token_id
-            ).transpose(1,2)
+            ).transpose(1, 2)
         elif decoder_input_ids is None and decoder_inputs_embeds is None:
             audio_encoder_outputs = self.audio_encoder(
@@ -2064,9 +2080,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if decoder_attention_mask is not None:
             decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
         if prompt_hidden_states is not None:
-            prompt_hidden_states = prompt_hidden_states.repeat((2,1,1))
+            prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
         if prompt_attention_mask is not None:
-            prompt_attention_mask = prompt_attention_mask.repeat((2,1))
+            prompt_attention_mask = prompt_attention_mask.repeat((2, 1))
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
@@ -2083,7 +2099,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             # we only want to use prompt signal in the 1st generation step but keeping the attention mask
             prompt_hidden_states = None
         return {
             "input_ids": None, # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
@@ -2244,7 +2259,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         return model_kwargs
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
@@ -2586,7 +2601,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
         _, mask = self.decoder.build_delay_pattern_mask(
             input_ids,
@@ -2595,10 +2609,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             max_length=output_ids.shape[1],
         )
-        mask = (mask != generation_config.bos_token_id)&(mask != generation_config.pad_token_id)
-        output_ids = output_ids[mask].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+        output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
         # append the frame dimension back to the audio codes
         output_ids = output_ids[None, ...]
@@ -2607,7 +2619,11 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size
-        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        decode_sequentially = (
+            generation_config.bos_token_id in output_ids
+            or generation_config.pad_token_id in output_ids
+            or generation_config.eos_token_id in output_ids
+        )
         if not decode_sequentially:
             output_values = self.audio_encoder.decode(
                 output_ids,
@@ -2617,16 +2633,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0,1)) == 0)
-                if sample_mask.sum()>0:
+                sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
+                if sample_mask.sum() > 0:
                     sample = sample[:, :, sample_mask]
                     sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
-                    output_values.append(sample.transpose(0,2))
+                    output_values.append(sample.transpose(0, 2))
                 else:
-                    output_values.append(torch.zeros((1,1,1)).to(self.device))
+                    output_values.append(torch.zeros((1, 1, 1)).to(self.device))
         # TODO: we should keep track of output length as well. Not really straightfoward tbh
-        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)
+        output_values = (
+            torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
+            .squeeze(-1)
+            .squeeze(-1)
+        )
         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
...
 import dac
 # Download a model
 model_path = dac.utils.download(model_type="44khz")
 model = dac.DAC.load(model_path)
@@ -10,6 +10,7 @@ hf_dac = DACModel(DACConfig())
 hf_dac.model.load_state_dict(model.state_dict())
 from transformers import AutoConfig, AutoModel
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
...
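A minimal usage sketch, not part of the commit, of what the AutoConfig/AutoModel registration in the conversion script above enables: once DACConfig and DACModel are registered, the wrapper round-trips through the standard Auto classes. The checkpoint directory name below is hypothetical, and default DACConfig values are assumed.

    # Sketch only: assumes default DACConfig values and a writable local path.
    from transformers import AutoConfig, AutoModel
    from parler_tts import DACConfig, DACModel

    AutoConfig.register("dac", DACConfig)    # map model_type "dac" -> DACConfig
    AutoModel.register(DACConfig, DACModel)  # map DACConfig -> DACModel

    hf_dac = DACModel(DACConfig())
    hf_dac.save_pretrained("./dac_wrapper_checkpoint")  # hypothetical output directory
    reloaded = AutoModel.from_pretrained("./dac_wrapper_checkpoint")  # resolved via the registry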
@@ -65,6 +65,7 @@ from transformers.integrations import is_wandb_available
 from transformers import AutoConfig, AutoModel
 from parler_tts import DACConfig, DACModel
 from transformers.modeling_outputs import BaseModelOutput
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -73,7 +74,12 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory
-from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig, apply_delay_pattern_mask, build_delay_pattern_mask
+from parler_tts import (
+    ParlerTTSForConditionalGeneration,
+    ParlerTTSConfig,
+    apply_delay_pattern_mask,
+    build_delay_pattern_mask,
+)
 if is_wandb_available():
     from wandb import Audio
@@ -90,8 +96,10 @@ logger = logging.getLogger(__name__)
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
 _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
 def get_last_checkpoint(folder):
     content = os.listdir(folder)
     checkpoints = [
@@ -103,6 +111,7 @@ def get_last_checkpoint(folder):
         return
     return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
 def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
     """Helper function to sort saved checkpoints from oldest to newest."""
     ordering_and_checkpoint_path = []
@@ -118,6 +127,7 @@ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[
     checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
     return checkpoints_sorted
 def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
     """Helper function to delete old checkpoints."""
     if save_total_limit is None or save_total_limit <= 0:
@@ -133,6 +143,7 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
             logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
             shutil.rmtree(checkpoint, ignore_errors=True)
 def log_metric(
     accelerator,
     metrics: Dict,
@@ -152,6 +163,7 @@ def log_metric(
     log_metrics[f"{prefix}/learning_rate"] = learning_rate
     accelerator.log(log_metrics, step=step)
 def log_pred(
     accelerator,
     pred_descriptions: List[str],
@@ -182,22 +194,27 @@ def log_pred(
         )
         # wandb can only loads 100 audios per step
-        wandb_tracker.log({
+        wandb_tracker.log(
+            {
                 "Speech samples": [
                     Audio(
                         audio,
                         caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
                         sample_rate=sampling_rate,
                     )
-                    for (i, audio) in enumerate(audios[:min(len(audios), 100)])
-                ]},
-                step=step)
+                    for (i, audio) in enumerate(audios[: min(len(audios), 100)])
+                ]
+            },
+            step=step,
+        )
 @dataclass
 class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
     # TODO: pretrain from scratch
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
@@ -212,7 +229,8 @@ class ModelArguments:
         default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
     )
     prompt_tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"}
+        default=None,
+        metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
     )
     cache_dir: Optional[str] = field(
         default=None,
@@ -256,7 +274,6 @@ class ModelArguments:
     )
 @dataclass
 class DataTrainingArguments:
     """
@@ -333,11 +350,11 @@ class DataTrainingArguments:
         default="audio",
         metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
     )
-    description_column_name: str = field( #TODO
+    description_column_name: str = field( # TODO
         default=None,
         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
     )
-    prompt_column_name: str = field( #TODO
+    prompt_column_name: str = field( # TODO
         default=None,
         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
     )
@@ -382,28 +399,31 @@ class DataTrainingArguments:
         default=500, metadata={"help": "If set, max description lengths in number of characters."}
     )
     max_prompt_token_length: int = field(
-        default=None, metadata={
+        default=None,
+        metadata={
             "help": (
                 "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
                 "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
             )
-        }
+        },
     )
     max_description_token_length: int = field(
-        default=None, metadata={
+        default=None,
+        metadata={
             "help": (
                 "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
                 "Also, used to set maximum desription token length if `pad_to_max_length=True`."
            )
-        }
+        },
     )
     pad_to_max_length: bool = field(
-        default=False, metadata={
+        default=False,
+        metadata={
             "help": (
                 "If `True`, pad audio, prompt and description to a maximum length set with respectively "
                 "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
             )
-        }
+        },
     )
     preprocessing_only: bool = field(
         default=False,
@@ -444,16 +464,9 @@ class DataTrainingArguments:
     )
     add_audio_samples_to_wandb: bool = field(
         default=False,
-        metadata={
-            "help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."
-        }
-    )
-    id_column_name: str = field(
-        default=None,
-        metadata={
-            "help": "id column name."
-        }
-    )
+        metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
+    )
+    id_column_name: str = field(default=None, metadata={"help": "id column name."})
     wandb_project: str = field(
         default="parler-speech",
         metadata={"help": "The name of the wandb project."},
@@ -462,23 +475,15 @@ class DataTrainingArguments:
         default=None,
         metadata={
             "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
-        }
+        },
     )
-    temporary_save_to_disk: str = field(
-        default=None,
-        metadata={
-            "help": "Temporarily save audio labels here."
-        }
-    )
+    temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
     pad_to_multiple_of: Optional[int] = field(
         default=2,
-        metadata={
-            "help": (
-                "Pad to multiple of for tokenizers."
-            )
-        },
+        metadata={"help": ("Pad to multiple of for tokenizers.")},
     )
 @dataclass
 class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
     dtype: Optional[str] = field(
@@ -492,13 +497,10 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
     )
     audio_encode_per_device_eval_batch_size: int = field(
         default=8,
-        metadata={
-            "help": (
-                "TODO"
-            )
-        },
+        metadata={"help": ("TODO")},
     )
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -512,7 +514,6 @@ class DataCollatorEncodecWithPadding:
     max_length: Optional[int] = None
     padding: Optional[str] = "longest"
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
@@ -564,26 +565,37 @@ class DataCollatorParlerTTSWithPadding:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        labels = [torch.tensor(feature["labels"]).transpose(0,1) for feature in features]
+        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
         # (bsz, seq_len, num_codebooks)
-        labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
-        if self.audio_max_length is not None and self.padding=="max_length":
-            labels = torch.nn.functional.pad(labels, pad=(0,0,0,max(self.audio_max_length-labels.shape[1], 0)))
+        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
+        if self.audio_max_length is not None and self.padding == "max_length":
+            labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
-        input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.description_max_length)
+        input_ids = self.description_tokenizer.pad(
+            input_ids,
+            return_tensors="pt",
+            padding=self.padding,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            max_length=self.description_max_length,
+        )
-        batch= {"labels":labels, **input_ids}
-        if self.audio_max_length is not None and self.padding=="max_length":
+        batch = {"labels": labels, **input_ids}
+        if self.audio_max_length is not None and self.padding == "max_length":
             # if we do torch.compile, we need to also specify the attention_mask
             decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
             batch["decoder_attention_mask"] = decoder_attention_mask
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
-        prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.prompt_max_length)
+        prompt_input_ids = self.prompt_tokenizer.pad(
+            prompt_input_ids,
+            return_tensors="pt",
+            padding=self.padding,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            max_length=self.prompt_max_length,
+        )
         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
@@ -591,13 +603,16 @@ class DataCollatorParlerTTSWithPadding:
         if self.feature_extractor_input_name in features[0]:
             # TODO (YL): verify that it works - IMPORTANT -> probably not working
-            input_values = [{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features]
+            input_values = [
+                {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
+            ]
             input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
-            batch[self.feature_extractor_input_name: input_values]
+            batch[self.feature_extractor_input_name : input_values]
         return batch
 def convert_dataset_str_to_list(
     dataset_names,
     dataset_config_names,
@@ -660,7 +675,7 @@ def load_multiple_datasets(
     accelerator: Accelerator,
     dataset_names: Union[List, str],
     dataset_config_names: Union[List, str],
-    metadata_dataset_names: Optional[str]=None,
+    metadata_dataset_names: Optional[str] = None,
     splits: Optional[Union[List, str]] = None,
     label_column_names: Optional[List] = None,
     stopping_strategy: Optional[str] = "first_exhausted",
@@ -699,13 +714,13 @@ def load_multiple_datasets(
         if sampling_rate is not None and audio_column_name is not None:
             # resample target audio
-            dataset = dataset.cast_column(
-                audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
-            )
+            dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
         metadata_dataset_name = dataset_dict["metadata_dataset_name"]
         if metadata_dataset_name is not None:
-            logger.info(f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}')
+            logger.info(
+                f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
+            )
             metadata_dataset = load_dataset(
                 metadata_dataset_name,
                 dataset_dict["config"],
@@ -743,7 +758,9 @@ def load_multiple_datasets(
             # We might have applied some transformations to the prompts (e.g punctuation restoration)
             # so we make sure to remove it from the original dataset
             if prompt_column_name in dataset.column_names:
-                logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
+                logger.info(
+                    f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']"
+                )
                 dataset.remove_columns(prompt_column_name)
             metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
@@ -752,8 +769,18 @@ def load_multiple_datasets(
             dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
             if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
-                if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
-                    raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
+                if (
+                    len(
+                        dataset.filter(
+                            lambda id1, id2: id1 != id2,
+                            input_columns=[id_column_name, f"metadata_{id_column_name}"],
+                        )
+                    )
+                    != 0
+                ):
+                    raise ValueError(
+                        f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}"
+                    )
             dataset_features = dataset.features.keys()
@@ -779,7 +806,6 @@ def load_multiple_datasets(
     return interleaved_dataset
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -804,8 +830,14 @@ def main():
     else:
         mixed_precision = "no"
-    if data_args.pad_to_max_length and (data_args.max_duration_in_seconds is None or data_args.max_prompt_token_length is None or data_args.max_description_token_length is None):
-        raise ValueError("`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`")
+    if data_args.pad_to_max_length and (
+        data_args.max_duration_in_seconds is None
+        or data_args.max_prompt_token_length is None
+        or data_args.max_description_token_length is None
+    ):
+        raise ValueError(
+            "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
+        )
     padding = "max_length" if data_args.pad_to_max_length else "longest"
@@ -813,7 +845,7 @@ def main():
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
     if training_args.torch_compile:
         # TODO(YL): add more compile modes?
-        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) #reduce-overhead
+        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
@@ -823,7 +855,9 @@ def main():
         kwargs_handlers=kwargs_handlers,
     )
-    accelerator.init_trackers(project_name=data_args.wandb_project, config={
+    accelerator.init_trackers(
+        project_name=data_args.wandb_project,
+        config={
             "learning_rate": training_args.learning_rate,
             "model_name_or_path": model_args.model_name_or_path,
             "num_train_epochs": training_args.num_train_epochs,
@@ -831,16 +865,16 @@ def main():
             "per_device_train_batch_size": training_args.per_device_train_batch_size,
             "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
             "mixed_precision": mixed_precision,
-            "lr_scheduler_type":training_args.lr_scheduler_type,
-            "warmup_steps":training_args.warmup_steps,
-            "freeze_text_encoder":model_args.freeze_text_encoder,
-            "max_duration_in_seconds":data_args.max_duration_in_seconds,
+            "lr_scheduler_type": training_args.lr_scheduler_type,
+            "warmup_steps": training_args.warmup_steps,
+            "freeze_text_encoder": model_args.freeze_text_encoder,
+            "max_duration_in_seconds": data_args.max_duration_in_seconds,
             "weight_decay": training_args.weight_decay,
             "adam_beta1": training_args.adam_beta1,
             "adam_beta2": training_args.adam_beta2,
             "temperature": model_args.temperature,
-    })
+        },
+    )
     # Detecting last checkpoint and eventually continue from last checkpoint
     last_checkpoint = None
@@ -857,7 +891,6 @@ def main():
                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
             )
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -880,7 +913,6 @@ def main():
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
     logger.info("Training/evaluation parameters %s", training_args)
     # Set seed before initializing model.
@@ -920,7 +952,9 @@ def main():
     )
     if model_args.use_fast_tokenizer:
-        logger.warning("Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235")
+        logger.warning(
+            "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
+        )
         prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
         description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
@@ -938,7 +972,7 @@ def main():
     columns_to_keep = {
         "target_audio_column_name": data_args.target_audio_column_name,
-        "prompt_column_name": data_args.prompt_column_name
+        "prompt_column_name": data_args.prompt_column_name,
     }
     if data_args.description_column_name is not None:
         columns_to_keep["description_column_name"] = data_args.description_column_name
...@@ -977,7 +1011,9 @@ def main(): ...@@ -977,7 +1011,9 @@ def main():
raw_datasets["eval"] = load_multiple_datasets( raw_datasets["eval"] = load_multiple_datasets(
accelerator, accelerator,
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name, data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name, data_args.eval_dataset_config_name
if data_args.eval_dataset_config_name
else data_args.train_dataset_config_name,
metadata_dataset_names=data_args.eval_metadata_dataset_name,
splits=data_args.eval_split_name,
cache_dir=model_args.cache_dir,
@@ -991,8 +1027,9 @@ def main():
)
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# 3. Next, let's load the config.
# TODO(YL): add the option to create the config from scratch
@@ -1005,10 +1042,16 @@ def main():
# update pad token id and decoder_start_token_id
# TODO(YL): verify if this makes sense, maybe should do it for model.decoder
config.update(
{
"pad_token_id": model_args.pad_token_id
if model_args.pad_token_id is not None
else model.config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id
if model_args.decoder_start_token_id is not None
else model.config.decoder_start_token_id,
}
)
# create model + TODO(YL): not from_pretrained probably
model = ParlerTTSForConditionalGeneration.from_pretrained(
@@ -1087,7 +1130,7 @@ def main():
# We use Accelerate to perform distributed inference
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
# Now we encode the audio labels with encodec.
####### B. Encode audio
@@ -1101,7 +1144,13 @@ def main():
else:
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(
feature_extractor,
audio_column_name=target_audio_column_name,
feature_extractor_input_name=feature_extractor_input_name,
max_length=max_target_length,
padding=padding,
)
def apply_audio_decoder(batch):
len_audio = batch.pop("len_audio")
@@ -1111,7 +1160,7 @@ def main():
output = {}
output["len_audio"] = len_audio
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
output["labels"] = labels.squeeze(0).transpose(1, 2)
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
return output
@@ -1133,10 +1182,10 @@ def main():
generate_labels = accelerator.gather_for_metrics(generate_labels)
if accelerator.is_main_process:
lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
rat = generate_labels["ratio"].cpu().squeeze()
lens = generate_labels["len_audio"].cpu().squeeze()
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
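# (editor's note) `int(ratio * length)` maps each original audio length in samples to its number of
# codec frames, so the batch padding introduced by the encoder collator is dropped here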
all_generated_labels.extend(lab)
all_lens.extend(lens)
@@ -1146,7 +1195,10 @@ def main():
if accelerator.is_main_process:
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
tmp_labels.save_to_disk(
os.path.join(data_args.temporary_save_to_disk, split),
num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
)
accelerator.wait_for_everyone()
del all_generated_labels
@@ -1154,18 +1206,19 @@ def main():
with accelerator.main_process_first():
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
def postprocess_dataset(labels):
# (1, codebooks, seq_len)
labels = torch.tensor(labels).unsqueeze(0)
# add bos
labels = torch.cat([bos_labels, labels], dim=-1)
labels, delay_pattern_mask = build_delay_pattern_mask(
labels,
bos_token_id=audio_encoder_bos_token_id,
pad_token_id=audio_encoder_eos_token_id,
max_length=labels.shape[-1] + num_codebooks,
num_codebooks=num_codebooks,
)
# the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
# to take care of EOS
@@ -1174,14 +1227,13 @@ def main():
# - [B, B, c, d, E, E, E]
# - [B, B, B, e, f, E, E]
# - [B, B, B, B, g, h, E]
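# (editor's note) positions that the delay pattern leaves unfilled are marked with -1 in
# `delay_pattern_mask` (the trailing triangle of E's above); the next line writes the EOS id into
# those positions so every codebook gets explicit stop targets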
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
# the first timestamp is associated with a row full of BOS, let's get rid of it
# we also remove the last timestamps (full of PAD)
output = {"labels": labels[:, 1:]}
return output
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
@@ -1191,11 +1243,9 @@ def main():
desc="Postprocessing labeling",
)
accelerator.free_memory()
del generate_labels, all_lens
with accelerator.main_process_first():
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
@@ -1231,10 +1281,12 @@ def main():
if data_args.save_to_disk is not None and not dataset_was_precomputed:
if accelerator.is_main_process:
vectorized_datasets.save_to_disk(
data_args.save_to_disk,
num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
)
logger.info(f"Dataset saved at {data_args.save_to_disk}")
audio_max_length = None
if training_args.torch_compile:
audio_max_length = max(vectorized_datasets["train"]["target_length"])
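# (editor's note) with `torch_compile`, padding every batch to a fixed `audio_max_length` keeps tensor
# shapes static and avoids recompilation; the `DataCollatorParlerTTSWithPadding` below receives this value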
@@ -1252,12 +1304,13 @@ def main():
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
# cached dataset
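# (editor's note, illustrative only; assumes the standard HfArgumentParser spelling of these data args)
# a first run with `--preprocessing_only --save_to_disk /path/to/cache` builds and stores the vectorized
# dataset, and a second run without `--preprocessing_only` then trains from that cached copy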
if data_args.preprocessing_only and data_args.save_to_disk is None:
raise ValueError(
"`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
)
elif data_args.preprocessing_only:
logger.info(f"Data preprocessing finished. Files saved at {data_args.save_to_disk}")
return
# 6. Next, we can prepare the training.
# Let's use CLAP similarity and WER (word error rate) as our evaluation metrics,
@@ -1267,12 +1320,13 @@ def main():
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
metric = evaluate.load("wer")
def clap_similarity(texts, audios, device):
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
clap.to(device)
with torch.no_grad():
text_features = clap.get_text_features(
clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
)
audio_features = clap.get_audio_features(clap_inputs["input_features"])
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
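# (editor's note) a higher cosine similarity between CLAP's audio and text embeddings means the
# generated audio matches its text description more closely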
@@ -1283,13 +1337,17 @@ def main():
def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
transcriptions = asr_pipeline(
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
batch_size=int(training_args.per_device_eval_batch_size),
)
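# (editor's note) the word error rate compares the ASR transcription of the generated audio against
# the original text prompt; lower is better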
word_error = 100 * metric.compute(
predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
)
return word_error, [t["text"] for t in transcriptions]
eval_methods = {"clap": clap_similarity, "wer": wer}
def compute_metrics(audios, descriptions, prompts, device="cpu"):
@@ -1297,9 +1355,7 @@ def main():
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios]
results = {"clap": eval_methods["clap"](texts, audios, device)}
word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
results["wer"] = word_error
@@ -1312,7 +1368,6 @@ def main():
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
if training_args.max_steps < 0:
num_epochs = int(training_args.num_train_epochs)
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
@@ -1325,15 +1380,13 @@ def main():
steps_per_epoch = total_train_steps
if training_args.eval_steps is None:
logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
eval_steps = steps_per_epoch
else:
eval_steps = training_args.eval_steps
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
# Define optimizer, LR scheduler, collator
optimizer = torch.optim.AdamW(
@@ -1354,11 +1407,17 @@ def main():
# Instantiate custom data collator
data_collator = DataCollatorParlerTTSWithPadding(
audio_feature_extractor=feature_extractor,
feature_extractor_input_name=feature_extractor_input_name,
prompt_tokenizer=prompt_tokenizer,
description_tokenizer=description_tokenizer,
pad_to_multiple_of=data_args.pad_to_multiple_of,
padding=padding,
prompt_max_length=data_args.max_prompt_token_length,
description_max_length=data_args.max_description_token_length,
audio_max_length=audio_max_length,
)
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
@@ -1387,7 +1446,6 @@ def main():
elif last_checkpoint is not None:
checkpoint = last_checkpoint
if accelerator.is_main_process:
if training_args.push_to_hub:
# Retrieve or infer repo_name
@@ -1412,17 +1470,21 @@ def main():
# only the main process saves them
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if (
model_args.prompt_tokenizer_name is None
and model_args.description_tokenizer_name
or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
):
prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
logger.warning(
"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
)
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
if checkpoint is not None:
accelerator.load_state(checkpoint)
# Find num steps and epoch from saved state string pattern
@@ -1470,9 +1532,13 @@ def main():
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed":
encoder_outputs = model.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
else:
encoder_outputs = model.module.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
batch["encoder_outputs"] = encoder_outputs batch["encoder_outputs"] = encoder_outputs
outputs = model(**batch) outputs = model(**batch)
...@@ -1484,7 +1550,11 @@ def main(): ...@@ -1484,7 +1550,11 @@ def main():
return ce_loss, metrics return ce_loss, metrics
# Define eval fn # Define eval fn
def eval_step(batch, accelerator, autocast_kwargs,): def eval_step(
batch,
accelerator,
autocast_kwargs,
):
eval_model = model if not training_args.torch_compile else model._orig_mod
eval_model.eval()
@@ -1493,9 +1563,13 @@ def main():
with accelerator.autocast(autocast_handler=autocast_kwargs):
with torch.no_grad():
if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
encoder_outputs = eval_model.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
else:
encoder_outputs = eval_model.module.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
batch["encoder_outputs"] = encoder_outputs batch["encoder_outputs"] = encoder_outputs
with torch.no_grad(): with torch.no_grad():
...@@ -1507,7 +1581,7 @@ def main(): ...@@ -1507,7 +1581,7 @@ def main():
def generate_step(batch): def generate_step(batch):
batch.pop("decoder_attention_mask", None) batch.pop("decoder_attention_mask", None)
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper = mixed_precision != "fp16").eval() eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
if training_args.torch_compile: if training_args.torch_compile:
eval_model = model._orig_mod eval_model = model._orig_mod
@@ -1518,7 +1592,7 @@ def main():
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO(YL): add args
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
train_dataloader = DataLoader(
vectorized_datasets["train"],
collate_fn=data_collator,
@@ -1546,7 +1620,6 @@ def main():
lr_scheduler.step()
optimizer.zero_grad()
# Check if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
steps_trained_progress_bar.update(1)
@@ -1643,8 +1716,12 @@ def main():
# Gather all predictions and targets
# TODO: also add prompt ids
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.pad_across_processes(
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
)
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
(generated_audios, input_ids, prompts)
)
eval_preds.extend(generated_audios.to("cpu"))
eval_descriptions.extend(input_ids.to("cpu"))
eval_prompts.extend(prompts.to("cpu"))
@@ -1652,7 +1729,8 @@ def main():
eval_time = time.time() - eval_start
# normalize eval metrics
eval_metrics = {
key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics]))
for key in eval_metrics[0]
}
# compute metrics
@@ -1697,7 +1775,6 @@ def main():
eval_prompts = []
batch = release_memory(batch)
# flush the train metrics
train_start = time.time()
@@ -1712,7 +1789,6 @@ def main():
accelerator.end_training()
if __name__ == "__main__":
set_start_method("spawn")
main()