Commit 31a54850 authored by Yoach Lacombe

make style

parent 91542bfa
......@@ -13,7 +13,7 @@ encodec_vocab_size = encodec.codebook_size
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
vocab_size=encodec_vocab_size + 1,
max_position_embeddings=2048,
num_hidden_layers=4,
ffn_dim=512,
......@@ -27,34 +27,32 @@ decoder_config = ParlerTTSDecoderConfig(
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size+1,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
# TODO: how to make it stop?
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
vocab_size = t5.vocab_size
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size+1
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")
\ No newline at end of file
model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")
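# A rough sketch of how the freshly assembled tiny checkpoint could be smoke-tested.
# The tokenizer checkpoint and the generate() keywords below are assumptions for
# illustration (the actual `text_model` value is defined earlier in this script).
import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

tiny_model = ParlerTTSForConditionalGeneration.from_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")  # assumed to match `text_model`

description = tokenizer("A female speaker with a slightly noisy recording.", return_tensors="pt")
prompt = tokenizer("Hello, world.", return_tensors="pt")
with torch.no_grad():
    audio = tiny_model.generate(
        input_ids=description.input_ids,
        attention_mask=description.attention_mask,
        prompt_input_ids=prompt.input_ids,
    )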
......@@ -20,7 +20,7 @@ encodec_vocab_size = encodec.codebook_size
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
vocab_size=encodec_vocab_size + 1,
max_position_embeddings=2048,
num_hidden_layers=4,
ffn_dim=512,
......@@ -34,34 +34,32 @@ decoder_config = ParlerTTSDecoderConfig(
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size+1,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
# TODO: how to make it stop?
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
vocab_size = t5.vocab_size
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size+1
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-dac-model/")
\ No newline at end of file
model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-dac-model/")
......@@ -21,8 +21,8 @@ encodec_vocab_size = encodec.codebook_size
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
max_position_embeddings=3000, # 30 s = 2580
vocab_size=encodec_vocab_size + 1,
max_position_embeddings=3000, # 30 s = 2580
num_hidden_layers=12,
ffn_dim=4096,
num_attention_heads=16,
......@@ -35,33 +35,31 @@ decoder_config = ParlerTTSDecoderConfig(
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size+1,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
vocab_size = t5.vocab_size
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size+1
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained("/raid/yoach/tmp/artefacts/small-stable-speech-untrained/")
\ No newline at end of file
model.save_pretrained("/raid/yoach/tmp/artefacts/small-stable-speech-untrained/")
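# Where the 30-second generation budget above comes from, as a minimal worked example
# (86 frames per second is the codec frame rate implied by the "30 s = 2580" comment):
frame_rate = 86                                # audio codec frames per second (assumption)
seconds = 30
max_audio_tokens = int(seconds * frame_rate)   # 30 * 86 = 2580 decoder steps per codebook
assert max_audio_tokens == 2580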
......@@ -21,8 +21,8 @@ encodec_vocab_size = encodec.codebook_size
decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size+1,
max_position_embeddings=4096, # 30 s = 2580
vocab_size=encodec_vocab_size + 1,
max_position_embeddings=4096, # 30 s = 2580
num_hidden_layers=8,
ffn_dim=3072,
num_attention_heads=12,
......@@ -35,33 +35,31 @@ decoder_config = ParlerTTSDecoderConfig(
activation_dropout=0.0,
pad_token_id=encodec_vocab_size,
eos_token_id=encodec_vocab_size,
bos_token_id=encodec_vocab_size+1,
bos_token_id=encodec_vocab_size + 1,
num_codebooks=num_codebooks,
)
decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version,
decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
vocab_size = t5.vocab_size
vocab_size=t5.vocab_size,
)
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = encodec_vocab_size+1
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
model.generation_config.pad_token_id = encodec_vocab_size
model.generation_config.eos_token_id = encodec_vocab_size
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained("/raid/yoach/tmp/artefacts/stable-speech-untrained-75M/")
\ No newline at end of file
model.save_pretrained("/raid/yoach/tmp/artefacts/stable-speech-untrained-75M/")
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
from .modeling_parler_tts import (
ParlerTTSForCausalLM,
ParlerTTSForConditionalGeneration,
apply_delay_pattern_mask,
build_delay_pattern_mask,
)
from .dac_wrapper import DACConfig, DACModel
\ No newline at end of file
from .dac_wrapper import DACConfig, DACModel
......@@ -81,7 +81,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
def __init__(
self,
vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos)
vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos)
max_position_embeddings=2048,
num_hidden_layers=24,
ffn_dim=4096,
......
from .configuration_dac import DACConfig
from .modeling_dac import DACModel
\ No newline at end of file
from .modeling_dac import DACModel
......@@ -8,17 +8,16 @@ class DACConfig(PretrainedConfig):
def __init__(
self,
num_codebooks: int = 9,
model_bitrate: int = 8, # kbps
model_bitrate: int = 8, # kbps
codebook_size: int = 1024,
latent_dim: int = 1024,
frame_rate: int = 86,
**kwargs,
):
self.codebook_size = codebook_size
self.model_bitrate = model_bitrate
self.latent_dim = latent_dim
self.num_codebooks = num_codebooks
self.frame_rate = frame_rate
super().__init__(**kwargs)
\ No newline at end of file
super().__init__(**kwargs)
......@@ -7,22 +7,24 @@ from .configuration_dac import DACConfig
from dac.model import DAC
# model doesn't support batching yet
# model doesn't support batching yet
class DACModel(PreTrainedModel):
config_class = DACConfig
def __init__(self, config):
super().__init__(config)
self.model = DAC(
n_codebooks = config.num_codebooks,
latent_dim = config.latent_dim,
codebook_size = config.codebook_size,
n_codebooks=config.num_codebooks,
latent_dim=config.latent_dim,
codebook_size=config.codebook_size,
)
def encode(self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None):
def encode(
self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
):
"""
Encodes the input audio waveform into discrete codes.
......@@ -44,7 +46,7 @@ class DACModel(PreTrainedModel):
factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
`codebook` of shape `[batch_size, num_codebooks, frames]`.
Scale is not used here.
"""
_, channels, input_length = input_values.shape
......@@ -52,12 +54,12 @@ class DACModel(PreTrainedModel):
raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
audio_data = self.model.preprocess(input_values, sample_rate)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# TODO: for now, no chunk length
chunk_length = None # self.config.chunk_length
chunk_length = None # self.config.chunk_length
if chunk_length is None:
chunk_length = input_length
stride = input_length
......@@ -79,9 +81,9 @@ class DACModel(PreTrainedModel):
for offset in range(0, input_length - step, stride):
mask = padding_mask[..., offset : offset + chunk_length].bool()
frame = audio_data[:, :, offset : offset + chunk_length]
scale = None
_, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
encoded_frames.append(encoded_frame)
scales.append(scale)
......@@ -92,15 +94,14 @@ class DACModel(PreTrainedModel):
return (encoded_frames, scales)
return EncodecEncoderOutput(encoded_frames, scales)
def decode(
self,
audio_codes,
audio_scales,
padding_mask = None,
return_dict = None,
):
self,
audio_codes,
audio_scales,
padding_mask=None,
return_dict=None,
):
"""
Decodes the given frames into an output audio waveform.
......@@ -125,12 +126,12 @@ class DACModel(PreTrainedModel):
if len(audio_codes) != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
audio_values = self.model.decode(audio_values)
if not return_dict:
return (audio_values,)
return EncodecDecoderOutput(audio_values)
def forward(self, tensor):
raise ValueError(f"`DACModel.forward` not implemented yet")
\ No newline at end of file
raise ValueError("`DACModel.forward` is not implemented yet")
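# A rough usage sketch for the wrapper above, assuming the `dac` (descript-audio-codec)
# package is installed; shapes and field names follow the EncodecEncoderOutput /
# EncodecDecoderOutput containers that encode() and decode() return.
import torch

hf_dac = DACModel(DACConfig()).eval()                 # randomly initialised, shape-checking only
waveform = torch.randn(1, 1, 44_100)                  # ~1 s of fake mono audio at 44.1 kHz
padding_mask = torch.ones_like(waveform, dtype=torch.bool)

with torch.no_grad():
    enc = hf_dac.encode(waveform, padding_mask=padding_mask, sample_rate=44_100, return_dict=True)
    # enc.audio_codes: roughly (1 frame, batch, num_codebooks, ~86 steps) for 1 s of audio
    dec = hf_dac.decode(enc.audio_codes, enc.audio_scales, return_dict=True)
    # dec.audio_values: reconstructed waveform (meaningless here since the weights are untrained)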
......@@ -46,7 +46,6 @@ from transformers.utils import (
from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
......@@ -60,6 +59,7 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
]
def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
"""Apply a delay pattern mask to the decoder input ids, only preserving predictions where
the mask is set to -1, and otherwise setting to the value detailed in the mask."""
......@@ -68,7 +68,10 @@ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
return input_ids
def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
def build_delay_pattern_mask(
input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int
):
"""Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
......@@ -91,9 +94,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input_ids.shape
input_ids_shifted = (
torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
)
input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
# we only apply the mask if we have a large enough seq len - otherwise we return as is
if max_length < 2 * num_codebooks - 1:
......@@ -115,7 +116,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
bos_mask = ~(bos_delay_pattern).to(input_ids.device)
eos_mask = ~(eos_delay_pattern).to(input_ids.device)
mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
......@@ -132,6 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
return input_ids, pattern_mask
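# A minimal sketch of the delayed pattern described in the docstring above, assuming
# 4 codebooks and max_length=8; B/P stand in for bos/pad token ids (2049/2048 here only
# to mirror the encodec_vocab_size convention used in the conversion scripts).
import torch

num_codebooks, max_length = 4, 8
B, P = 2049, 2048
pattern = torch.full((num_codebooks, max_length), -1, dtype=torch.long)
for k in range(num_codebooks):
    pattern[k, : k + 1] = B                                 # k + 1 leading BOS entries
    pattern[k, max_length - (num_codebooks - 1 - k) :] = P  # trailing PAD entries
# Row k looks like [B, ..., B, -1, ..., -1, P, ..., P]; the -1 slots are the positions
# that apply_delay_pattern_mask leaves to the model's own predictions.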
@dataclass
class ParlerTTSUnconditionalInput(ModelOutput):
"""
......@@ -732,7 +734,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
# TODO: not right dim
embed_dim = config.vocab_size + 1 # + 1 for pad token id
embed_dim = config.vocab_size + 1 # + 1 for pad token id
self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
)
......@@ -762,8 +764,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -799,11 +801,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
# if prompt_hidden_states are provided, prepend them to inputs_embeds and update the input shape
if prompt_hidden_states is not None:
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
# As it is, the masked ids from the prompt will still count in the position embeddings
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
......@@ -812,11 +814,25 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
if past_key_values is None:
attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
attention_mask = torch.cat(
[
prompt_attention_mask,
torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype),
],
dim=1,
)
else:
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0] ,generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
attention_mask = torch.cat(
[
prompt_attention_mask,
torch.ones(
(input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
),
],
dim=1,
)
input_shape = inputs_embeds.size()[:-1]
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......@@ -957,8 +973,8 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -1050,8 +1066,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -1098,26 +1114,26 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
if labels is not None:
loss = torch.zeros([], device=self.device)
# since the encoder hidden states have been concatenated to the hidden states, take the last hidden states corresponding to the labels
logits = lm_logits[:,:,-labels.shape[1]:]
logits = lm_logits[:, :, -labels.shape[1] :]
loss_fct = CrossEntropyLoss()
loss = torch.zeros([], device=self.device)
# logits: (bsz, num_codebooks, seq_len, vocab_size), labels: (bsz, seq_len, num_codebooks)
labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
# we use every codebooks token AND one single EOS at the end of each codebooks
mask = (input_ids.transpose(1,2) != self.config.eos_token_id) & ((labels != -100))
# we use every codebook's tokens AND a single EOS at the end of each codebook
mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))
# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_mask = mask[..., codebook].contiguous().view(-1)
codebook_labels = labels[..., codebook].contiguous().view(-1)
codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
loss += codebook_loss
loss = loss / self.config.num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
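# Condensed sketch of the per-codebook loss computed above, assuming logits of shape
# (bsz, num_codebooks, seq_len, vocab_size) and labels / mask of shape
# (bsz, seq_len, num_codebooks); masked positions (extra EOS, -100 labels) are dropped.
import torch
from torch.nn import CrossEntropyLoss

def per_codebook_cross_entropy(logits, labels, mask, num_codebooks):
    loss_fct = CrossEntropyLoss()
    loss = torch.zeros([], device=logits.device)
    for codebook in range(num_codebooks):
        codebook_logits = logits[:, codebook].reshape(-1, logits.shape[-1])
        codebook_labels = labels[..., codebook].reshape(-1)
        codebook_mask = mask[..., codebook].reshape(-1)
        loss = loss + loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
    return loss / num_codebooks  # average over codebooks, as in the loop above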
......@@ -1169,7 +1185,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
input_ids = input_ids.repeat((2, 1))
if attention_mask is not None:
attention_mask = attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = torch.concatenate(
[prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
......@@ -1182,7 +1198,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# we only want to use the prompt signal in the 1st generation step but keep the attention mask
prompt_hidden_states = None
......@@ -1200,7 +1216,9 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
}
# Ignore copy
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
def build_delay_pattern_mask(
self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
):
"""Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
......@@ -1486,9 +1504,10 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the delay pattern mask by filtering the eos and bos token ids from the delay pattern mask
output_ids = output_ids[(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
batch_size, self.decoder.num_codebooks, -1
)
output_ids = output_ids[
(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
& (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
].reshape(batch_size, self.decoder.num_codebooks, -1)
if generation_config.return_dict_in_generate:
outputs.sequences = output_ids
......@@ -1520,9 +1539,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
"Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
)
if config is None:
config = ParlerTTSConfig.from_sub_models_config(
text_encoder.config, audio_encoder.config, decoder.config
)
config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
......@@ -1584,10 +1601,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
# prompt embeddings
self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
if self.text_encoder.get_output_embeddings() is not None:
raise ValueError(
......@@ -1603,7 +1619,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
self.post_init()
def _init_weights(self, module):
std = self.decoder.config.initializer_factor
if isinstance(module, (nn.Linear, nn.Conv1d)):
......@@ -1885,9 +1901,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
prompt_input_ids: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_input_ids: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
......@@ -1964,7 +1980,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
......@@ -1974,7 +1990,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# TODO: verify it does what's expected
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
).transpose(1,2)
).transpose(1, 2)
elif decoder_input_ids is None and decoder_inputs_embeds is None:
audio_encoder_outputs = self.audio_encoder(
......@@ -2064,9 +2080,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = prompt_hidden_states.repeat((2,1,1))
prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.repeat((2,1))
prompt_attention_mask = prompt_attention_mask.repeat((2, 1))
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
......@@ -2079,11 +2095,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# we only want to use prompt signal in the 1st generation step but keeping the attention mask
prompt_hidden_states = None
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
......@@ -2191,7 +2206,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
return model_kwargs
def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids)
return model_kwargs
......@@ -2244,7 +2259,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
def resize_token_embeddings(self, *args, **kwargs):
# TODO: now it's possible with prompt_embeddings
......@@ -2281,13 +2296,13 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
batch_size = value.shape[0]
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def freeze_encoders(self, freeze_text_encoder=True):
if freeze_text_encoder:
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_encoder._requires_grad = False
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.audio_encoder._requires_grad = False
......@@ -2425,7 +2440,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
model_input_name,
guidance_scale=generation_config.guidance_scale,
)
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
# `prompt_hidden_states` are created and added to `model_kwargs`
model_kwargs = self._prepare_prompt_kwargs_for_generation(
......@@ -2586,7 +2601,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
# apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the delay pattern mask by filtering the eos and bos token ids from the delay pattern mask
_, mask = self.decoder.build_delay_pattern_mask(
input_ids,
......@@ -2594,11 +2608,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)
mask = (mask != generation_config.bos_token_id)&(mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(
batch_size, self.decoder.num_codebooks, -1
)
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
# append the frame dimension back to the audio codes
output_ids = output_ids[None, ...]
......@@ -2607,8 +2619,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if audio_scales is None:
audio_scales = [None] * batch_size
decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
if not decode_sequentially:
decode_sequentially = (
generation_config.bos_token_id in output_ids
or generation_config.pad_token_id in output_ids
or generation_config.eos_token_id in output_ids
)
if not decode_sequentially:
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
......@@ -2617,16 +2633,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id]
sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0,1)) == 0)
if sample_mask.sum()>0:
sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
if sample_mask.sum() > 0:
sample = sample[:, :, sample_mask]
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
output_values.append(sample.transpose(0,2))
output_values.append(sample.transpose(0, 2))
else:
output_values.append(torch.zeros((1,1,1)).to(self.device))
output_values.append(torch.zeros((1, 1, 1)).to(self.device))
# TODO: we should keep track of the output length as well. Not really straightforward tbh
output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)
output_values = (
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
.squeeze(-1)
.squeeze(-1)
)
if generation_config.return_dict_in_generate:
outputs.sequences = output_values
......
import dac
# Download a model
model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path)
......@@ -10,6 +10,7 @@ hf_dac = DACModel(DACConfig())
hf_dac.model.load_state_dict(model.state_dict())
from transformers import AutoConfig, AutoModel
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -20,4 +21,4 @@ hf_dac.push_to_hub("ylacombe/dac_44khZ_8kbps")
from transformers import EncodecFeatureExtractor
EncodecFeatureExtractor(sampling_rate=44100).push_to_hub("ylacombe/dac_44khZ_8kbps")
\ No newline at end of file
EncodecFeatureExtractor(sampling_rate=44100).push_to_hub("ylacombe/dac_44khZ_8kbps")
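# A small sketch of loading the pushed checkpoint back through the Auto classes in a
# fresh session, assuming the `dac` package is available (registration is per-session):
from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
from parler_tts import DACConfig, DACModel

AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)

dac = AutoModel.from_pretrained("ylacombe/dac_44khZ_8kbps")
feature_extractor = EncodecFeatureExtractor.from_pretrained("ylacombe/dac_44khZ_8kbps")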
......@@ -65,6 +65,7 @@ from transformers.integrations import is_wandb_available
from transformers import AutoConfig, AutoModel
from parler_tts import DACConfig, DACModel
from transformers.modeling_outputs import BaseModelOutput
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
......@@ -73,7 +74,12 @@ from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig, apply_delay_pattern_mask, build_delay_pattern_mask
from parler_tts import (
ParlerTTSForConditionalGeneration,
ParlerTTSConfig,
apply_delay_pattern_mask,
build_delay_pattern_mask,
)
if is_wandb_available():
from wandb import Audio
......@@ -90,8 +96,10 @@ logger = logging.getLogger(__name__)
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
......@@ -103,6 +111,7 @@ def get_last_checkpoint(folder):
return
return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
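# Quick sanity check of the checkpoint naming convention matched above:
import re

_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
match = _RE_CHECKPOINT.search("checkpoint-4000-epoch-2")
print(match.groups())  # ('4000', '2'); get_last_checkpoint keeps the max over the step group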
def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
......@@ -118,6 +127,7 @@ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
"""Helper function to delete old checkpoints."""
if save_total_limit is None or save_total_limit <= 0:
......@@ -133,6 +143,7 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint, ignore_errors=True)
def log_metric(
accelerator,
metrics: Dict,
......@@ -152,6 +163,7 @@ def log_metric(
log_metrics[f"{prefix}/learning_rate"] = learning_rate
accelerator.log(log_metrics, step=step)
def log_pred(
accelerator,
pred_descriptions: List[str],
......@@ -180,24 +192,29 @@ def log_pred(
step=step,
commit=False,
)
# wandb can only log 100 audio samples per step
wandb_tracker.log({
wandb_tracker.log(
{
"Speech samples": [
Audio(
audio,
caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
sample_rate=sampling_rate,
)
for (i, audio) in enumerate(audios[:min(len(audios), 100)])
]},
step=step)
for (i, audio) in enumerate(audios[: min(len(audios), 100)])
]
},
step=step,
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
# TODO: pretrain from scratch
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
......@@ -212,7 +229,8 @@ class ModelArguments:
default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
)
prompt_tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"}
default=None,
metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
)
cache_dir: Optional[str] = field(
default=None,
......@@ -251,10 +269,9 @@ class ModelArguments:
metadata={"help": "Generation max length."},
)
bandwidth: float = field(
default=6, # TODO
default=6, # TODO
metadata={"help": "Audio encoder bandwidth."},
)
@dataclass
......@@ -329,15 +346,15 @@ class DataTrainingArguments:
" librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
},
)
target_audio_column_name: str = field( # TODO
target_audio_column_name: str = field( # TODO
default="audio",
metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
)
description_column_name: str = field( #TODO
description_column_name: str = field( # TODO
default=None,
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
)
prompt_column_name: str = field( #TODO
prompt_column_name: str = field( # TODO
default=None,
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
)
......@@ -382,28 +399,31 @@ class DataTrainingArguments:
default=500, metadata={"help": "If set, max description lengths in number of characters."}
)
max_prompt_token_length: int = field(
default=None, metadata={
default=None,
metadata={
"help": (
"If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
"Also, used to set maximum prompt token length if `pad_to_max_length=True`."
)
}
},
)
max_description_token_length: int = field(
default=None, metadata={
default=None,
metadata={
"help": (
"If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
"Also, used to set maximum description token length if `pad_to_max_length=True`."
)
}
},
)
pad_to_max_length: bool = field(
default=False, metadata={
"help": (
"If `True`, pad audio, prompt and description to a maximum length set with respectively "
"`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
)
}
default=False,
metadata={
"help": (
"If `True`, pad audio, prompt and description to a maximum length set with respectively "
"`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
)
},
)
preprocessing_only: bool = field(
default=False,
......@@ -444,16 +464,9 @@ class DataTrainingArguments:
)
add_audio_samples_to_wandb: bool = field(
default=False,
metadata={
"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."
}
)
id_column_name: str = field(
default=None,
metadata={
"help": "id column name."
}
metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
)
id_column_name: str = field(default=None, metadata={"help": "id column name."})
wandb_project: str = field(
default="parler-speech",
metadata={"help": "The name of the wandb project."},
......@@ -462,23 +475,15 @@ class DataTrainingArguments:
default=None,
metadata={
"help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
}
)
temporary_save_to_disk: str = field(
default=None,
metadata={
"help": "Temporarily save audio labels here."
}
},
)
temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
pad_to_multiple_of: Optional[int] = field(
default=2,
metadata={
"help": (
"Pad to multiple of for tokenizers."
)
},
metadata={"help": ("Pad to multiple of for tokenizers.")},
)
@dataclass
class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
dtype: Optional[str] = field(
......@@ -492,17 +497,14 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
)
audio_encode_per_device_eval_batch_size: int = field(
default=8,
metadata={
"help": (
"TODO"
)
},
metadata={"help": ("TODO")},
)
@dataclass
class DataCollatorEncodecWithPadding:
"""
Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
to `max_length` if `max_length` is set and `padding=max_length`.
"""
......@@ -512,7 +514,6 @@ class DataCollatorEncodecWithPadding:
max_length: Optional[int] = None
padding: Optional[str] = "longest"
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
......@@ -523,7 +524,7 @@ class DataCollatorEncodecWithPadding:
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
return batch
@dataclass
class DataCollatorParlerTTSWithPadding:
"""
......@@ -563,41 +564,55 @@ class DataCollatorParlerTTSWithPadding:
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
labels = [torch.tensor(feature["labels"]).transpose(0,1) for feature in features]
labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
# (bsz, seq_len, num_codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
if self.audio_max_length is not None and self.padding=="max_length":
labels = torch.nn.functional.pad(labels, pad=(0,0,0,max(self.audio_max_length-labels.shape[1], 0)))
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
if self.audio_max_length is not None and self.padding == "max_length":
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.description_max_length)
batch= {"labels":labels, **input_ids}
if self.audio_max_length is not None and self.padding=="max_length":
input_ids = self.description_tokenizer.pad(
input_ids,
return_tensors="pt",
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
max_length=self.description_max_length,
)
batch = {"labels": labels, **input_ids}
if self.audio_max_length is not None and self.padding == "max_length":
# if we do torch.compile, we need to also specify the attention_mask
decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
batch["decoder_attention_mask"] = decoder_attention_mask
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.prompt_max_length)
prompt_input_ids = self.prompt_tokenizer.pad(
prompt_input_ids,
return_tensors="pt",
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
max_length=self.prompt_max_length,
)
batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
if "attention_mask" in prompt_input_ids:
batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
if self.feature_extractor_input_name in features[0]:
# TODO (YL): verify that it works - IMPORTANT -> probably not working
input_values = [{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features]
input_values = [
{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
]
input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
batch[self.feature_extractor_input_name: input_values]
batch[self.feature_extractor_input_name] = input_values
return batch
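# Toy illustration of the label padding performed at the top of __call__, assuming two
# samples with 3 and 5 codec frames, num_codebooks=2 and audio_max_length=6:
import torch

labels = [torch.zeros(3, 2, dtype=torch.long), torch.ones(5, 2, dtype=torch.long)]
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
audio_max_length = 6
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(audio_max_length - labels.shape[1], 0)))
print(labels.shape)  # torch.Size([2, 6, 2]); pad_sequence fills with -100, the final pad with 0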
def convert_dataset_str_to_list(
dataset_names,
dataset_config_names,
......@@ -660,7 +675,7 @@ def load_multiple_datasets(
accelerator: Accelerator,
dataset_names: Union[List, str],
dataset_config_names: Union[List, str],
metadata_dataset_names: Optional[str]=None,
metadata_dataset_names: Optional[str] = None,
splits: Optional[Union[List, str]] = None,
label_column_names: Optional[List] = None,
stopping_strategy: Optional[str] = "first_exhausted",
......@@ -696,16 +711,16 @@ def load_multiple_datasets(
**kwargs,
)
dataset_features = dataset.features.keys()
if sampling_rate is not None and audio_column_name is not None:
# resample target audio
dataset = dataset.cast_column(
audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
)
dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))
metadata_dataset_name = dataset_dict["metadata_dataset_name"]
if metadata_dataset_name is not None:
logger.info(f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}')
logger.info(
f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
)
metadata_dataset = load_dataset(
metadata_dataset_name,
dataset_dict["config"],
......@@ -713,7 +728,7 @@ def load_multiple_datasets(
streaming=streaming,
**kwargs,
)
# TODO(YL): I forgot to create unique ids for MLS english.
# To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
# if dataset_dict["name"] == "parler-tts/mls_eng_10k":
......@@ -728,32 +743,44 @@ def load_multiple_datasets(
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
f"- one of {', '.join(list(dataset.column_names))}."
)
)
if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
f"- one of {', '.join(list(metadata_dataset.column_names))}."
)
)
elif id_column_name is not None:
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
if prompt_column_name is not None:
# We might have applied some transformations to the prompts (e.g. punctuation restoration),
# so we make sure to remove them from the original dataset
if prompt_column_name in dataset.column_names:
logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']")
logger.info(
f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - {dataset_dict['split']}"
)
dataset = dataset.remove_columns(prompt_column_name)
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
if (
len(
dataset.filter(
lambda id1, id2: id1 != id2,
input_columns=[id_column_name, f"metadata_{id_column_name}"],
)
)
!= 0
):
raise ValueError(
f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}"
)
dataset_features = dataset.features.keys()
......@@ -778,8 +805,7 @@ def load_multiple_datasets(
return interleaved_dataset
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
......@@ -796,16 +822,22 @@ def main():
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_parler_tts", model_args, data_args)
if training_args.dtype == "float16":
mixed_precision = "fp16"
elif training_args.dtype == "bfloat16":
mixed_precision = "bf16"
else:
mixed_precision = "no"
if data_args.pad_to_max_length and (data_args.max_duration_in_seconds is None or data_args.max_prompt_token_length is None or data_args.max_description_token_length is None):
raise ValueError("`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`")
if data_args.pad_to_max_length and (
data_args.max_duration_in_seconds is None
or data_args.max_prompt_token_length is None
or data_args.max_description_token_length is None
):
raise ValueError(
"`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
)
padding = "max_length" if data_args.pad_to_max_length else "longest"
......@@ -813,8 +845,8 @@ def main():
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
if training_args.torch_compile:
# TODO(YL): add more compile modes?
kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) #reduce-overhead
kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
......@@ -822,26 +854,28 @@ def main():
project_dir=training_args.output_dir,
kwargs_handlers=kwargs_handlers,
)
accelerator.init_trackers(project_name=data_args.wandb_project, config={
"learning_rate": training_args.learning_rate,
"model_name_or_path": model_args.model_name_or_path,
"num_train_epochs": training_args.num_train_epochs,
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
"per_device_train_batch_size": training_args.per_device_train_batch_size,
"global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
"mixed_precision": mixed_precision,
"lr_scheduler_type":training_args.lr_scheduler_type,
"warmup_steps":training_args.warmup_steps,
"freeze_text_encoder":model_args.freeze_text_encoder,
"max_duration_in_seconds":data_args.max_duration_in_seconds,
"weight_decay": training_args.weight_decay,
"adam_beta1": training_args.adam_beta1,
"adam_beta2": training_args.adam_beta2,
"temperature": model_args.temperature,
})
accelerator.init_trackers(
project_name=data_args.wandb_project,
config={
"learning_rate": training_args.learning_rate,
"model_name_or_path": model_args.model_name_or_path,
"num_train_epochs": training_args.num_train_epochs,
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
"per_device_train_batch_size": training_args.per_device_train_batch_size,
"global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
"mixed_precision": mixed_precision,
"lr_scheduler_type": training_args.lr_scheduler_type,
"warmup_steps": training_args.warmup_steps,
"freeze_text_encoder": model_args.freeze_text_encoder,
"max_duration_in_seconds": data_args.max_duration_in_seconds,
"weight_decay": training_args.weight_decay,
"adam_beta1": training_args.adam_beta1,
"adam_beta2": training_args.adam_beta2,
"temperature": model_args.temperature,
},
)
# Detecting last checkpoint and eventually continue from last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
......@@ -856,7 +890,6 @@ def main():
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging
logging.basicConfig(
......@@ -880,17 +913,16 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
num_workers = data_args.preprocessing_num_workers
# 1. First, let's instantiate the feature extractor, tokenizers and model
# Note for distributed training, the .from_pretrained methods guarantee that only
# one local process can concurrently download model & vocab.
# load feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
......@@ -899,7 +931,7 @@ def main():
trust_remote_code=data_args.trust_remote_code,
)
sampling_rate = feature_extractor.sampling_rate
# load prompt tokenizer
prompt_tokenizer = AutoTokenizer.from_pretrained(
model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
......@@ -907,9 +939,9 @@ def main():
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
padding_side="left", # prompt has to be padded on the left bc it's preprend to codebooks hidden states
padding_side="left",  # prompt has to be padded on the left because it's prepended to the codebook hidden states
)
# load description tokenizer
description_tokenizer = AutoTokenizer.from_pretrained(
model_args.description_tokenizer_name or model_args.model_name_or_path,
......@@ -918,31 +950,33 @@ def main():
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
)
if model_args.use_fast_tokenizer:
logger.warning("Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235")
logger.warning(
"Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
)
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
# 2. Now, let's load the dataset
if data_args.save_to_disk is not None:
os.makedirs(data_args.save_to_disk, exist_ok=True)
# assume that the dataset has been saved to `save_to_disk` if the latter is not empty
dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
if dataset_was_precomputed:
vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
else:
else:
raw_datasets = DatasetDict()
columns_to_keep = {
"target_audio_column_name": data_args.target_audio_column_name,
"prompt_column_name": data_args.prompt_column_name
"prompt_column_name": data_args.prompt_column_name,
}
if data_args.description_column_name is not None:
columns_to_keep["description_column_name"] = data_args.description_column_name
if training_args.do_train:
raw_datasets["train"] = load_multiple_datasets(
accelerator,
......@@ -961,14 +995,14 @@ def main():
sampling_rate=sampling_rate,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
for key in columns_to_keep:
if columns_to_keep[key] not in raw_datasets["train"].column_names:
raise ValueError(
f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
f" Make sure to set `--{key}` to the correct audio column - one of"
f" {', '.join(raw_datasets['train'].column_names)}."
)
)
if data_args.max_train_samples is not None:
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
......@@ -977,7 +1011,9 @@ def main():
raw_datasets["eval"] = load_multiple_datasets(
accelerator,
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
data_args.eval_dataset_config_name
if data_args.eval_dataset_config_name
else data_args.train_dataset_config_name,
metadata_dataset_names=data_args.eval_metadata_dataset_name,
splits=data_args.eval_split_name,
cache_dir=model_args.cache_dir,
......@@ -991,8 +1027,9 @@ def main():
)
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
raw_datasets["eval"] = (
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# 3. Next, let's load the config.
# TODO(YL): add the option to create the config from scratch
......@@ -1002,14 +1039,20 @@ def main():
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
)
# update pad token id and decoder_start_token_id
# TODO(YL): verify if this makes sense, maybe should do it for model.decoder
config.update({
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else model.config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
})
config.update(
{
"pad_token_id": model_args.pad_token_id
if model_args.pad_token_id is not None
else model.config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id
if model_args.decoder_start_token_id is not None
else model.config.decoder_start_token_id,
}
)
# create model + TODO(YL): not from_pretrained probably
model = ParlerTTSForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
......@@ -1018,16 +1061,16 @@ def main():
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
)
# enable gradient checkpointing if necessary
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# 4. Now we preprocess the datasets including loading the audio, resampling and normalization
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
# so that we just need to set the correct target sampling rate and normalize the input
# via the `feature_extractor`
# derive max & min input length for sample rate & max duration
sampling_rate = feature_extractor.sampling_rate
max_target_length = data_args.max_duration_in_seconds * sampling_rate
......@@ -1042,18 +1085,18 @@ def main():
max_length = model.generation_config.max_length
num_codebooks = model.decoder.config.num_codebooks
bandwidth = model_args.bandwidth
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
# TODO: remove when releasing
# Test all-gather - used for warmup and to avoid timeouts
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor)
accelerator.wait_for_everyone()
if not dataset_was_precomputed:
if not dataset_was_precomputed:
# Filter on text length
if description_column_name is not None and data_args.max_text_length is not None:
with accelerator.main_process_first():
......@@ -1068,13 +1111,13 @@ def main():
# We need to tokenize the texts.
def pass_through_processors(description, prompt):
batch = {}
batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
# TODO: add possibility to train without description column
batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
return batch
with accelerator.main_process_first():
# this is a trick to avoid rewriting the entire audio column, which takes ages
vectorized_datasets = raw_datasets.map(
......@@ -1087,13 +1130,13 @@ def main():
# We use Accelerate to perform distributed inference
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled= (mixed_precision != "fp16"))
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
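# NOTE: passing this handler to accelerator.autocast(...) with enabled=False effectively turns
# autocast off inside that context, so under fp16 training the T5 text encoder (see the
# train/eval steps below) still runs in full precision.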
# Now we encode the audio labels with encodec.
####### B. Encode audio
logger.info("*** Encode target audio with encodec ***")
# no need to prepare audio_decoder because it is only used for inference, without mixed precision
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
if training_args.torch_compile:
......@@ -1101,7 +1144,13 @@ def main():
else:
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length,padding=padding)
encoder_data_collator = DataCollatorEncodecWithPadding(
feature_extractor,
audio_column_name=target_audio_column_name,
feature_extractor_input_name=feature_extractor_input_name,
max_length=max_target_length,
padding=padding,
)
def apply_audio_decoder(batch):
len_audio = batch.pop("len_audio")
......@@ -1111,8 +1160,8 @@ def main():
output = {}
output["len_audio"] = len_audio
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
output["labels"] = labels.squeeze(0).transpose(1,2)
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
output["labels"] = labels.squeeze(0).transpose(1, 2)
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
return output
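# The audio encoder pads every clip in the batch to the longest one, so "ratio" (codec frames per
# audio sample) is kept around to trim each label sequence back to its true length after
# gathering, via the `int(ratio * length)` slicing below.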
for split in vectorized_datasets:
......@@ -1123,82 +1172,83 @@ def main():
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
)
data_loader = accelerator.prepare(data_loader)
data_loader = accelerator.prepare(data_loader)
all_generated_labels = []
all_lens = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
generate_labels = apply_audio_decoder(batch)
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
generate_labels = accelerator.gather_for_metrics(generate_labels)
if accelerator.is_main_process:
lab = generate_labels["labels"].cpu().transpose(1,2).to(torch.int16)
lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
rat = generate_labels["ratio"].cpu().squeeze()
lens = generate_labels["len_audio"].cpu().squeeze()
lab = [l[:, :int(ratio*length)] for (l, ratio, length) in zip(lab, rat, lens)]
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
all_generated_labels.extend(lab)
all_lens.extend(lens)
# (1, codebooks, seq_len) where seq_len=1
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
if accelerator.is_main_process:
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
tmp_labels.save_to_disk(os.path.join(data_args.temporary_save_to_disk, split), num_proc=1 if split == "eval" else data_args.preprocessing_num_workers)
tmp_labels.save_to_disk(
os.path.join(data_args.temporary_save_to_disk, split),
num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
)
accelerator.wait_for_everyone()
del all_generated_labels
tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
with accelerator.main_process_first():
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
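# concatenate_datasets(..., axis=1) appends the freshly encoded "labels"/"target_length" columns
# to the tokenized split; this relies on both datasets having kept the same row order.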
def postprocess_dataset(labels):
# (1, codebooks, seq_len)
labels = torch.tensor(labels).unsqueeze(0)
labels = torch.tensor(labels).unsqueeze(0)
# add bos
labels = torch.cat([bos_labels, labels], dim=-1)
labels, delay_pattern_mask = build_delay_pattern_mask(labels,
bos_token_id=audio_encoder_bos_token_id,
pad_token_id=audio_encoder_eos_token_id,
max_length=labels.shape[-1] + num_codebooks,
num_codebooks=num_codebooks)
labels, delay_pattern_mask = build_delay_pattern_mask(
labels,
bos_token_id=audio_encoder_bos_token_id,
pad_token_id=audio_encoder_eos_token_id,
max_length=labels.shape[-1] + num_codebooks,
num_codebooks=num_codebooks,
)
# the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
# to take care of EOS
# we want labels to look like this:
# - [B, a, b, E, E, E, E]
# - [B, B, c, d, E, E, E]
# - [B, B, B, e, f, E, E]
# - [B, B, B, B, g, h, E]
labels = torch.where(delay_pattern_mask==-1, audio_encoder_eos_token_id, delay_pattern_mask)
# - [B, B, B, B, g, h, E]
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
# the first timestamp corresponds to a row full of BOS, so let's get rid of it
# we also remove the last timestamps (full of PAD)
output = {"labels": labels[:, 1:]}
return output
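# Illustrative sketch only (not used by the training flow): a hand-built version of the delayed
# layout drawn in the comments above, for num_codebooks=4 and a 2-token sequence per codebook.
# Token values and the BOS/EOS placeholders are made up for the example; only torch is assumed.
def _toy_delay_pattern(num_codebooks=4, seq_len=2, bos=-1, eos=-2):
    codes = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)  # fake codes
    width = seq_len + num_codebooks + 1  # width matches the `max_length` used above (bos + seq_len + num_codebooks)
    delayed = torch.full((num_codebooks, width), eos)
    for q in range(num_codebooks):
        delayed[q, : q + 1] = bos  # codebook q starts with q + 1 BOS tokens
        delayed[q, q + 1 : q + 1 + seq_len] = codes[q]  # its codes are shifted right by q steps
    # dropping the first (all-BOS) timestep mirrors the `labels[:, 1:]` slicing above
    return delayed[:, 1:]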
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor.
num_proc=data_args.preprocessing_num_workers,  # this one is resource-consuming if there are many processes.
input_columns=["labels"],
desc="Postprocessing labeling",
)
accelerator.free_memory()
del generate_labels, all_lens
with accelerator.main_process_first():
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
# That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
def is_audio_in_length_range(length):
......@@ -1210,7 +1260,7 @@ def main():
num_proc=num_workers,
input_columns=["target_length"],
)
if description_column_name is not None and data_args.max_description_token_length is not None:
with accelerator.main_process_first():
# filter description that is shorter than max_text_length
......@@ -1228,22 +1278,24 @@ def main():
num_proc=num_workers,
input_columns=["prompt_input_ids"],
)
if data_args.save_to_disk is not None and not dataset_was_precomputed:
if accelerator.is_main_process:
vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"])-1))
vectorized_datasets.save_to_disk(
data_args.save_to_disk,
num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
)
logger.info(f"Dataset saved at {data_args.save_to_disk}")
audio_max_length = None
if training_args.torch_compile:
audio_max_length = max(vectorized_datasets["train"]["target_length"])
with accelerator.main_process_first():
with accelerator.main_process_first():
max_sample = vectorized_datasets["train"].filter(
lambda x: x == audio_max_length,
num_proc=num_workers,
input_columns=["target_length"],
)
lambda x: x == audio_max_length,
num_proc=num_workers,
input_columns=["target_length"],
)
audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
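# torch.compile re-traces whenever tensor shapes change, so the longest label length in the train
# split is looked up once and passed to the collator (audio_max_length below), presumably so every
# batch is padded to the same fixed size.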
# for large datasets it is advised to run the preprocessing on a
......@@ -1252,44 +1304,50 @@ def main():
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
# cached dataset
if data_args.preprocessing_only and data_args.save_to_disk is None:
raise ValueError("`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally.")
raise ValueError(
"`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
)
elif data_args.preprocessing_only:
logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
return
# 6. Next, we can prepare the training.
# Let's use CLAP similarity and WER as our evaluation metrics.
# Define evaluation metrics used during training, *i.e.* CLAP similarity. TODO: allow using another CLAP
clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
metric = evaluate.load("wer")
def clap_similarity(texts, audios, device):
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
clap.to(device)
with torch.no_grad():
text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
text_features = clap.get_text_features(
clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
)
audio_features = clap.get_audio_features(clap_inputs["input_features"])
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
clap.to("cpu")
clap_inputs.to("cpu")
return cosine_sim.mean().to("cpu")
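# Hedged usage sketch only (never called): clap_similarity expects plain-text descriptions and
# raw 1-D waveforms (numpy arrays), as prepared in compute_metrics below. The silent 48 kHz
# one-second clip and the description string are made-up example inputs, not script values.
def _toy_clap_call(device="cpu"):
    import numpy as np

    dummy_audio = [np.zeros(48_000, dtype=np.float32)]  # one second of silence
    return clap_similarity(["a calm female voice reading a book"], dummy_audio, device)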
def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
transcriptions = asr_pipeline([{'raw': audio, 'sampling_rate': sampling_rate} for audio in audios], batch_size=int(training_args.per_device_eval_batch_size))
word_error = 100 * metric.compute(predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts])
transcriptions = asr_pipeline(
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
batch_size=int(training_args.per_device_eval_batch_size),
)
word_error = 100 * metric.compute(
predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
)
return word_error, [t["text"] for t in transcriptions]
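# wer() returns the corpus-level word error rate (in %) of the Whisper transcriptions against the
# text prompts, together with the raw transcriptions for logging.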
eval_methods = {"clap": clap_similarity, "wer": wer}
def compute_metrics(audios, descriptions, prompts, device="cpu"):
......@@ -1297,22 +1355,19 @@ def main():
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios]
results = {
"clap": eval_methods["clap"](texts, audios, device)
}
results = {"clap": eval_methods["clap"](texts, audios, device)}
word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
results["wer"] = word_error
return results, texts, prompts, audios, transcriptions
# Define Training Schedule
# Store some constants
per_device_train_batch_size = int(training_args.per_device_train_batch_size)
train_batch_size = per_device_train_batch_size * accelerator.num_processes
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
if training_args.max_steps < 0:
num_epochs = int(training_args.num_train_epochs)
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
......@@ -1325,16 +1380,14 @@ def main():
steps_per_epoch = total_train_steps
if training_args.eval_steps is None:
logger.info(
f"eval_steps is not set, evaluating at the end of each epoch"
)
logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
eval_steps = steps_per_epoch
else:
eval_steps = training_args.eval_steps
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled= (mixed_precision != "fp16"))
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
# Define optimizer, LR scheduler, collator
optimizer = torch.optim.AdamW(
params=model.parameters(),
......@@ -1354,14 +1407,20 @@ def main():
# Instantiate custom data collator
data_collator = DataCollatorParlerTTSWithPadding(
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of,
padding=padding, prompt_max_length=data_args.max_prompt_token_length, description_max_length=data_args.max_description_token_length, audio_max_length = audio_max_length
audio_feature_extractor=feature_extractor,
feature_extractor_input_name=feature_extractor_input_name,
prompt_tokenizer=prompt_tokenizer,
description_tokenizer=description_tokenizer,
pad_to_multiple_of=data_args.pad_to_multiple_of,
padding=padding,
prompt_max_length=data_args.max_prompt_token_length,
description_max_length=data_args.max_description_token_length,
audio_max_length=audio_max_length,
)
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
......@@ -1386,8 +1445,7 @@ def main():
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
if accelerator.is_main_process:
if training_args.push_to_hub:
# Retrieve or infer repo_name
......@@ -1405,23 +1463,27 @@ def main():
elif training_args.output_dir is not None:
os.makedirs(training_args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with accelerator.main_process_first():
# only the main process saves them
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name==model_args.description_tokenizer_name):
if (
model_args.prompt_tokenizer_name is None
and model_args.description_tokenizer_name
or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
):
prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
logger.warning(
f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
)
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
if checkpoint is not None:
accelerator.load_state(checkpoint)
......@@ -1439,7 +1501,7 @@ def main():
for epoch in range(0, epochs_trained):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
if training_args.max_steps < 0:
# we know exactly the number of steps per epoch, so can skip through the required number of batches
resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
......@@ -1451,13 +1513,13 @@ def main():
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
else:
resume_step = None
gen_kwargs = {
"do_sample": model_args.do_sample,
"temperature": model_args.temperature,
"max_length": model_args.max_length,
}
# Define gradient update step fn
def train_step(
batch,
......@@ -1465,26 +1527,34 @@ def main():
autocast_kwargs,
):
model.train()
if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed":
encoder_outputs = model.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
encoder_outputs = model.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
else:
encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
encoder_outputs = model.module.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
batch["encoder_outputs"] = encoder_outputs
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
# TODO: add CE per codebook
# TODO: add CE per codebook
metrics = {"loss": ce_loss}
return ce_loss, metrics
# Define eval fn
def eval_step(batch, accelerator, autocast_kwargs,):
def eval_step(
batch,
accelerator,
autocast_kwargs,
):
eval_model = model if not training_args.torch_compile else model._orig_mod
eval_model.eval()
......@@ -1493,9 +1563,13 @@ def main():
with accelerator.autocast(autocast_handler=autocast_kwargs):
with torch.no_grad():
if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
encoder_outputs = eval_model.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
encoder_outputs = eval_model.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
else:
encoder_outputs = eval_model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
encoder_outputs = eval_model.module.text_encoder(
input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
)
batch["encoder_outputs"] = encoder_outputs
with torch.no_grad():
......@@ -1507,7 +1581,7 @@ def main():
def generate_step(batch):
batch.pop("decoder_attention_mask", None)
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper = mixed_precision != "fp16").eval()
eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
if training_args.torch_compile:
eval_model = model._orig_mod
......@@ -1518,7 +1592,7 @@ def main():
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO(YL): add args
sampler = LengthGroupedSampler(train_batch_size, lengths = vectorized_datasets["train"]["target_length"])
sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
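# LengthGroupedSampler batches samples of similar target_length together, which keeps per-batch
# padding (and wasted compute on pad tokens) low.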
train_dataloader = DataLoader(
vectorized_datasets["train"],
collate_fn=data_collator,
......@@ -1546,7 +1620,6 @@ def main():
lr_scheduler.step()
optimizer.zero_grad()
# Check if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
steps_trained_progress_bar.update(1)
......@@ -1597,7 +1670,7 @@ def main():
eval_descriptions = []
eval_prompts = []
eval_start = time.time()
# release training input batch
batch = release_memory(batch)
......@@ -1634,17 +1707,21 @@ def main():
validation_dataloader = accelerator.prepare(validation_dataloader)
# generation
for batch in tqdm(
validation_dataloader,
desc=f"Evaluating - Generation ...",
position=2,
disable=not accelerator.is_local_main_process,
):
validation_dataloader,
desc=f"Evaluating - Generation ...",
position=2,
disable=not accelerator.is_local_main_process,
):
generated_audios = generate_step(batch)
# Gather all predictions and targets
# TODO: also add prompt ids
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.pad_across_processes((generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0)
generated_audios, input_ids, prompts = accelerator.gather_for_metrics((generated_audios, input_ids, prompts))
generated_audios, input_ids, prompts = accelerator.pad_across_processes(
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
)
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
(generated_audios, input_ids, prompts)
)
eval_preds.extend(generated_audios.to("cpu"))
eval_descriptions.extend(input_ids.to("cpu"))
eval_prompts.extend(prompts.to("cpu"))
......@@ -1652,7 +1729,8 @@ def main():
eval_time = time.time() - eval_start
# normalize eval metrics
eval_metrics = {
key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics])) for key in eval_metrics[0]
key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics]))
for key in eval_metrics[0]
}
# compute metrics
......@@ -1689,7 +1767,7 @@ def main():
epoch=epoch,
prefix="eval",
)
# release eval batch and reset metrics
eval_metrics = []
eval_preds = []
......@@ -1697,7 +1775,6 @@ def main():
eval_prompts = []
batch = release_memory(batch)
# flush the train metrics
train_start = time.time()
......@@ -1710,9 +1787,8 @@ def main():
break
accelerator.end_training()
if __name__ == "__main__":
set_start_method("spawn")
main()
\ No newline at end of file
main()