"test/srt/vscode:/vscode.git/clone" did not exist on "9fcf73069f30bbc75cd52b7a36ec961129f239cb"
Commit 31a54850 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

make style

parent 91542bfa
@@ -13,7 +13,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -27,34 +27,32 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False  # True
 model.generation_config.guidance_scale = 1  # 3.0

 model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")
\ No newline at end of file
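A minimal sanity-check sketch for the artefact written above (not part of this diff; it assumes the local path from the script and the standard PreTrainedModel.from_pretrained API):

    from parler_tts import ParlerTTSForConditionalGeneration

    model = ParlerTTSForConditionalGeneration.from_pretrained("/raid/yoach/tmp/artefacts/tiny-model/")
    # generation defaults set in the script should round-trip through save_pretrained
    print(model.generation_config.decoder_start_token_id)  # encodec_vocab_size + 1
    print(model.generation_config.max_length)  # int(30 * frame_rate)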
@@ -20,7 +20,7 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
+    vocab_size=encodec_vocab_size + 1,
     max_position_embeddings=2048,
     num_hidden_layers=4,
     ffn_dim=512,
@@ -34,34 +34,32 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 # TODO: ?? how to make it stop ?
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False  # True
 model.generation_config.guidance_scale = 1  # 3.0

 model.save_pretrained("/raid/yoach/tmp/artefacts/tiny-dac-model/")
\ No newline at end of file
@@ -21,8 +21,8 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
-    max_position_embeddings=3000, # 30 s = 2580
+    vocab_size=encodec_vocab_size + 1,
+    max_position_embeddings=3000,  # 30 s = 2580
     num_hidden_layers=12,
     ffn_dim=4096,
     num_attention_heads=16,
@@ -35,33 +35,31 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False  # True
 model.generation_config.guidance_scale = 1  # 3.0

 model.save_pretrained("/raid/yoach/tmp/artefacts/small-stable-speech-untrained/")
\ No newline at end of file
@@ -21,8 +21,8 @@ encodec_vocab_size = encodec.codebook_size
 decoder_config = ParlerTTSDecoderConfig(
-    vocab_size=encodec_vocab_size+1,
-    max_position_embeddings=4096, # 30 s = 2580
+    vocab_size=encodec_vocab_size + 1,
+    max_position_embeddings=4096,  # 30 s = 2580
     num_hidden_layers=8,
     ffn_dim=3072,
     num_attention_heads=12,
@@ -35,33 +35,31 @@ decoder_config = ParlerTTSDecoderConfig(
     activation_dropout=0.0,
     pad_token_id=encodec_vocab_size,
     eos_token_id=encodec_vocab_size,
-    bos_token_id=encodec_vocab_size+1,
+    bos_token_id=encodec_vocab_size + 1,
     num_codebooks=num_codebooks,
 )
 decoder = ParlerTTSForCausalLM(decoder_config)
 decoder.save_pretrained("/raid/yoach/tmp/artefacts/decoder_small/")

 model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path=text_model,
     audio_encoder_pretrained_model_name_or_path=encodec_version,
     decoder_pretrained_model_name_or_path="/raid/yoach/tmp/artefacts/decoder_small/",
-    vocab_size = t5.vocab_size
+    vocab_size=t5.vocab_size,
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = encodec_vocab_size+1
+model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
 model.generation_config.pad_token_id = encodec_vocab_size
 model.generation_config.eos_token_id = encodec_vocab_size

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False  # True
 model.generation_config.guidance_scale = 1  # 3.0

 model.save_pretrained("/raid/yoach/tmp/artefacts/stable-speech-untrained-75M/")
\ No newline at end of file
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
-from .modeling_parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
+from .modeling_parler_tts import (
+    ParlerTTSForCausalLM,
+    ParlerTTSForConditionalGeneration,
+    apply_delay_pattern_mask,
+    build_delay_pattern_mask,
+)
 from .dac_wrapper import DACConfig, DACModel
\ No newline at end of file
@@ -81,7 +81,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos)
+        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
...
 from .configuration_dac import DACConfig
 from .modeling_dac import DACModel
\ No newline at end of file
@@ -8,17 +8,16 @@ class DACConfig(PretrainedConfig):
     def __init__(
         self,
         num_codebooks: int = 9,
-        model_bitrate: int = 8, # kbps
+        model_bitrate: int = 8,  # kbps
         codebook_size: int = 1024,
         latent_dim: int = 1024,
         frame_rate: int = 86,
         **kwargs,
     ):
         self.codebook_size = codebook_size
         self.model_bitrate = model_bitrate
         self.latent_dim = latent_dim
         self.num_codebooks = num_codebooks
         self.frame_rate = frame_rate

         super().__init__(**kwargs)
\ No newline at end of file
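A tiny illustrative sketch of how the config above is instantiated (default values only, purely hypothetical usage, not part of this diff):

    from parler_tts import DACConfig

    config = DACConfig()  # num_codebooks=9, model_bitrate=8, codebook_size=1024 by default
    assert config.latent_dim == 1024 and config.frame_rate == 86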
@@ -7,22 +7,24 @@ from .configuration_dac import DACConfig
 from dac.model import DAC

 # model doesn't support batching yet
 class DACModel(PreTrainedModel):
     config_class = DACConfig

     def __init__(self, config):
         super().__init__(config)

         self.model = DAC(
-            n_codebooks = config.num_codebooks,
-            latent_dim = config.latent_dim,
-            codebook_size = config.codebook_size,
+            n_codebooks=config.num_codebooks,
+            latent_dim=config.latent_dim,
+            codebook_size=config.codebook_size,
         )

-    def encode(self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None):
+    def encode(
+        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
+    ):
         """
         Encodes the input audio waveform into discrete codes.
@@ -44,7 +46,7 @@ class DACModel(PreTrainedModel):
                 factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
                 `codebook` of shape `[batch_size, num_codebooks, frames]`.
                 Scale is not used here.
         """
         _, channels, input_length = input_values.shape
@@ -52,12 +54,12 @@ class DACModel(PreTrainedModel):
             raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

         audio_data = self.model.preprocess(input_values, sample_rate)
         return_dict = return_dict if return_dict is not None else self.config.return_dict

         # TODO: for now, no chunk length
-        chunk_length = None # self.config.chunk_length
+        chunk_length = None  # self.config.chunk_length
         if chunk_length is None:
             chunk_length = input_length
             stride = input_length
@@ -79,9 +81,9 @@ class DACModel(PreTrainedModel):
         for offset in range(0, input_length - step, stride):
             mask = padding_mask[..., offset : offset + chunk_length].bool()
             frame = audio_data[:, :, offset : offset + chunk_length]
             scale = None
             _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
             encoded_frames.append(encoded_frame)
             scales.append(scale)
@@ -92,15 +94,14 @@ class DACModel(PreTrainedModel):
             return (encoded_frames, scales)

         return EncodecEncoderOutput(encoded_frames, scales)

     def decode(
         self,
         audio_codes,
         audio_scales,
-        padding_mask = None,
-        return_dict = None,
+        padding_mask=None,
+        return_dict=None,
     ):
         """
         Decodes the given frames into an output audio waveform.
@@ -125,12 +126,12 @@ class DACModel(PreTrainedModel):
         if len(audio_codes) != 1:
             raise ValueError(f"Expected one frame, got {len(audio_codes)}")

         audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
         audio_values = self.model.decode(audio_values)

         if not return_dict:
             return (audio_values,)
         return EncodecDecoderOutput(audio_values)

     def forward(self, tensor):
         raise ValueError(f"`DACModel.forward` not implemented yet")
\ No newline at end of file
@@ -46,7 +46,6 @@ from transformers.utils import (
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig

 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -60,6 +59,7 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ParlerTTS models at https://huggingface.co/models?filter=parler_tts
 ]

+
 def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
     the mask is set to -1, and otherwise setting to the value detailed in the mask."""
@@ -68,7 +68,10 @@ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
     return input_ids

-def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+
+def build_delay_pattern_mask(
+    input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int
+):
     """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
     one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
     are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -91,9 +94,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
     bsz, num_codebooks, seq_len = input_ids.shape

-    input_ids_shifted = (
-        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-    )
+    input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1

     # we only apply the mask if we have a large enough seq len - otherwise we return as is
     if max_length < 2 * num_codebooks - 1:
@@ -115,7 +116,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     bos_mask = ~(bos_delay_pattern).to(input_ids.device)
     eos_mask = ~(eos_delay_pattern).to(input_ids.device)
     mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
     input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id

     # find the first position to start generating - this is the first place we have the -1 token
     # and will always be in the first codebook (since it has no codebook offset)
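As a concrete reading of the 4-codebook, max-length-8 example mentioned in the docstring above, the delayed pattern has roughly this shape, where B is the bos token, P the pad token and -1 marks positions still to be predicted (illustration only, following the MusicGen convention this code is adapted from):

    # codebook 0: [B, -1, -1, -1, -1,  P,  P,  P]
    # codebook 1: [B,  B, -1, -1, -1, -1,  P,  P]
    # codebook 2: [B,  B,  B, -1, -1, -1, -1,  P]
    # codebook 3: [B,  B,  B,  B, -1, -1, -1, -1]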
@@ -132,6 +133,7 @@ def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad
     input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
     return input_ids, pattern_mask

+
 @dataclass
 class ParlerTTSUnconditionalInput(ModelOutput):
     """
@@ -732,7 +734,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
         self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

         # TODO: not right dim
-        embed_dim = config.vocab_size + 1 # + 1 for pad token id
+        embed_dim = config.vocab_size + 1  # + 1 for pad token id
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
         )
@@ -762,8 +764,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
-        prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -799,11 +801,11 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])

         # if prompt_hidden_states, fuse to inputs_embeds and update input shape
         if prompt_hidden_states is not None:
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)

         # As it is, the masked ids from the prompt will still count in the positions embeddings
         if prompt_attention_mask is not None and attention_mask is not None:
             attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
@@ -812,11 +814,25 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
                 "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
             )
             if past_key_values is None:
-                attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                attention_mask = torch.cat(
+                    [
+                        prompt_attention_mask,
+                        torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype),
+                    ],
+                    dim=1,
+                )
             else:
                 generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
-                attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0] ,generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+                attention_mask = torch.cat(
+                    [
+                        prompt_attention_mask,
+                        torch.ones(
+                            (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
+                        ),
+                    ],
+                    dim=1,
+                )

         input_shape = inputs_embeds.size()[:-1]
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -957,8 +973,8 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
-        prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -1050,8 +1066,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
-        prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -1098,26 +1114,26 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         if labels is not None:
             loss = torch.zeros([], device=self.device)
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
-            logits = lm_logits[:,:,-labels.shape[1]:]
+            logits = lm_logits[:, :, -labels.shape[1] :]

             loss_fct = CrossEntropyLoss()
             loss = torch.zeros([], device=self.device)

             # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)

             # we use every codebooks token AND one single EOS at the end of each codebooks
-            mask = (input_ids.transpose(1,2) != self.config.eos_token_id) & ((labels != -100))
+            mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))

             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
                 codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
                 codebook_mask = mask[..., codebook].contiguous().view(-1)
                 codebook_labels = labels[..., codebook].contiguous().view(-1)
                 codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
                 loss += codebook_loss

             loss = loss / self.config.num_codebooks

         # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
@@ -1169,7 +1185,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             input_ids = input_ids.repeat((2, 1))
             if attention_mask is not None:
                 attention_mask = attention_mask.repeat((2, 1))

             if prompt_hidden_states is not None:
                 prompt_hidden_states = torch.concatenate(
                     [prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
@@ -1182,7 +1198,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         if past_key_values is not None:
             input_ids = input_ids[:, -1:]

             # we only want to use prompt signal in the 1st generation step but keeping the attention mask
             prompt_hidden_states = None
@@ -1200,7 +1216,9 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         }

     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(
+        self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
+    ):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
@@ -1486,9 +1504,10 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(batch_size, self.decoder.num_codebooks, -1)

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_ids
@@ -1520,9 +1539,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 "Either a configuration has to be provided, or all three of text encoder, audio encoder and Parler-TTS decoder."
             )
         if config is None:
-            config = ParlerTTSConfig.from_sub_models_config(
-                text_encoder.config, audio_encoder.config, decoder.config
-            )
+            config = ParlerTTSConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
         else:
             if not isinstance(config, self.config_class):
                 raise ValueError(f"Config: {config} has to be of type {self.config_class}")
@@ -1584,10 +1601,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             and self.decoder.config.cross_attention_hidden_size is None
         ):
             self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)

         # prompt embeddings
         self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)

         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
@@ -1603,7 +1619,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
         self.post_init()

     def _init_weights(self, module):
         std = self.decoder.config.initializer_factor
         if isinstance(module, (nn.Linear, nn.Conv1d)):
@@ -1885,9 +1901,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        prompt_input_ids: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
-        prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
+        prompt_input_ids: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1964,7 +1980,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if attention_mask is not None:
             encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

         if prompt_hidden_states is None:
             if prompt_input_ids is not None:
                 prompt_hidden_states = self.embed_prompts(prompt_input_ids)
@@ -1974,7 +1990,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             # TODO: verify it does what's expected
             decoder_input_ids = shift_tokens_right(
                 labels, self.config.pad_token_id, self.config.decoder_start_token_id
-            ).transpose(1,2)
+            ).transpose(1, 2)

         elif decoder_input_ids is None and decoder_inputs_embeds is None:
             audio_encoder_outputs = self.audio_encoder(
@@ -2064,9 +2080,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if decoder_attention_mask is not None:
             decoder_attention_mask = decoder_attention_mask.repeat((2, 1))

         if prompt_hidden_states is not None:
-            prompt_hidden_states = prompt_hidden_states.repeat((2,1,1))
+            prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))

         if prompt_attention_mask is not None:
-            prompt_attention_mask = prompt_attention_mask.repeat((2,1))
+            prompt_attention_mask = prompt_attention_mask.repeat((2, 1))

         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
@@ -2079,11 +2095,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 remove_prefix_length = decoder_input_ids.shape[1] - 1

             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

         # we only want to use prompt signal in the 1st generation step but keeping the attention mask
         prompt_hidden_states = None

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
@@ -2191,7 +2206,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)

         return model_kwargs

     def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
         model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids)
         return model_kwargs
@@ -2244,7 +2259,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         return model_kwargs

     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)

     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
@@ -2281,13 +2296,13 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 batch_size = value.shape[0]
                 break
         return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

     def freeze_encoders(self, freeze_text_encoder=True):
         if freeze_text_encoder:
             for param in self.text_encoder.parameters():
                 param.requires_grad = False
             self.text_encoder._requires_grad = False

         for param in self.audio_encoder.parameters():
             param.requires_grad = False
         self.audio_encoder._requires_grad = False
@@ -2425,7 +2440,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             model_input_name,
             guidance_scale=generation_config.guidance_scale,
         )

         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
             # `prompt_hidden_states` are created and added to `model_kwargs`
             model_kwargs = self._prepare_prompt_kwargs_for_generation(
@@ -2586,7 +2601,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
         _, mask = self.decoder.build_delay_pattern_mask(
             input_ids,
@@ -2594,11 +2608,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             pad_token_id=generation_config.pad_token_id,
             max_length=output_ids.shape[1],
         )

-        mask = (mask != generation_config.bos_token_id)&(mask != generation_config.pad_token_id)
-        output_ids = output_ids[mask].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+        output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)

         # append the frame dimension back to the audio codes
         output_ids = output_ids[None, ...]
@@ -2607,8 +2619,12 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        decode_sequentially = (
+            generation_config.bos_token_id in output_ids
+            or generation_config.pad_token_id in output_ids
+            or generation_config.eos_token_id in output_ids
+        )
         if not decode_sequentially:
             output_values = self.audio_encoder.decode(
                 output_ids,
                 audio_scales=audio_scales,
@@ -2617,16 +2633,19 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0,1)) == 0)
-                if sample_mask.sum()>0:
+                sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
+                if sample_mask.sum() > 0:
                     sample = sample[:, :, sample_mask]
                     sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
-                    output_values.append(sample.transpose(0,2))
+                    output_values.append(sample.transpose(0, 2))
                 else:
-                    output_values.append(torch.zeros((1,1,1)).to(self.device))
+                    output_values.append(torch.zeros((1, 1, 1)).to(self.device))

             # TODO: we should keep track of output length as well. Not really straightfoward tbh
-            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)
+            output_values = (
+                torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
+                .squeeze(-1)
+                .squeeze(-1)
+            )

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
...
 import dac

 # Download a model
 model_path = dac.utils.download(model_type="44khz")
 model = dac.DAC.load(model_path)
@@ -10,6 +10,7 @@ hf_dac = DACModel(DACConfig())
 hf_dac.model.load_state_dict(model.state_dict())

 from transformers import AutoConfig, AutoModel

+
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -20,4 +21,4 @@ hf_dac.push_to_hub("ylacombe/dac_44khZ_8kbps")
 from transformers import EncodecFeatureExtractor

 EncodecFeatureExtractor(sampling_rate=44100).push_to_hub("ylacombe/dac_44khZ_8kbps")
\ No newline at end of file
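Once DACConfig/DACModel are registered with the Auto classes as in the script above, the pushed checkpoint can in principle be reloaded through the Auto API. A hedged sketch (assumes the registration is run in the same process and the repo name above is unchanged; not part of this diff):

    from transformers import AutoConfig, AutoModel
    from parler_tts import DACConfig, DACModel

    AutoConfig.register("dac", DACConfig)
    AutoModel.register(DACConfig, DACModel)
    hf_dac = AutoModel.from_pretrained("ylacombe/dac_44khZ_8kbps")  # resolves to DACModel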
@@ -65,6 +65,7 @@ from transformers.integrations import is_wandb_available
 from transformers import AutoConfig, AutoModel
 from parler_tts import DACConfig, DACModel
 from transformers.modeling_outputs import BaseModelOutput

+
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
@@ -73,7 +74,12 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory

-from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig, apply_delay_pattern_mask, build_delay_pattern_mask
+from parler_tts import (
+    ParlerTTSForConditionalGeneration,
+    ParlerTTSConfig,
+    apply_delay_pattern_mask,
+    build_delay_pattern_mask,
+)

 if is_wandb_available():
     from wandb import Audio
@@ -90,8 +96,10 @@ logger = logging.getLogger(__name__)
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)

+
 _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")

+
 def get_last_checkpoint(folder):
     content = os.listdir(folder)
     checkpoints = [
@@ -103,6 +111,7 @@ def get_last_checkpoint(folder):
         return
     return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))

+
 def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
     """Helper function to sort saved checkpoints from oldest to newest."""
     ordering_and_checkpoint_path = []
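For reference, the checkpoint naming convention that _RE_CHECKPOINT and the checkpoint helpers in this file rely on (example directory names are hypothetical):

    import re

    _RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
    assert _RE_CHECKPOINT.search("checkpoint-500-epoch-2").groups() == ("500", "2")
    assert _RE_CHECKPOINT.search("not-a-checkpoint") is None  # such folders are skipped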
@@ -118,6 +127,7 @@ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[
     checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
     return checkpoints_sorted

+
 def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
     """Helper function to delete old checkpoints."""
     if save_total_limit is None or save_total_limit <= 0:
@@ -133,6 +143,7 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
         logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
         shutil.rmtree(checkpoint, ignore_errors=True)

+
 def log_metric(
     accelerator,
     metrics: Dict,
@@ -152,6 +163,7 @@ def log_metric(
     log_metrics[f"{prefix}/learning_rate"] = learning_rate
     accelerator.log(log_metrics, step=step)

+
 def log_pred(
     accelerator,
     pred_descriptions: List[str],
@@ -180,24 +192,29 @@ def log_pred(
             step=step,
             commit=False,
         )
         # wandb can only loads 100 audios per step
-        wandb_tracker.log({
-            "Speech samples": [
-                Audio(
-                    audio,
-                    caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
-                    sample_rate=sampling_rate,
-                )
-                for (i, audio) in enumerate(audios[:min(len(audios), 100)])
-            ]},
-            step=step)
+        wandb_tracker.log(
+            {
+                "Speech samples": [
+                    Audio(
+                        audio,
+                        caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
+                        sample_rate=sampling_rate,
+                    )
+                    for (i, audio) in enumerate(audios[: min(len(audios), 100)])
+                ]
+            },
+            step=step,
+        )

+
 @dataclass
 class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """

     # TODO: pretrain from scratch
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
@@ -212,7 +229,8 @@ class ModelArguments:
         default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
     )
     prompt_tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"}
+        default=None,
+        metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
     )
     cache_dir: Optional[str] = field(
         default=None,
...@@ -251,10 +269,9 @@ class ModelArguments: ...@@ -251,10 +269,9 @@ class ModelArguments:
metadata={"help": "Generation max length."}, metadata={"help": "Generation max length."},
) )
bandwidth: float = field( bandwidth: float = field(
default=6, # TODO default=6, # TODO
metadata={"help": "Audio encoder bandwidth."}, metadata={"help": "Audio encoder bandwidth."},
) )
@dataclass
...@@ -329,15 +346,15 @@ class DataTrainingArguments:
            " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    target_audio_column_name: str = field(  # TODO
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
    )
    description_column_name: str = field(  # TODO
        default=None,
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
    )
    prompt_column_name: str = field(  # TODO
        default=None,
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
    )
...@@ -382,28 +399,31 @@ class DataTrainingArguments:
        default=500, metadata={"help": "If set, max description lengths in number of characters."}
    )
    max_prompt_token_length: int = field(
        default=None,
        metadata={
            "help": (
                "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
                "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
            )
        },
    )
    max_description_token_length: int = field(
        default=None,
        metadata={
            "help": (
                "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
                "Also, used to set maximum description token length if `pad_to_max_length=True`."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "If `True`, pad audio, prompt and description to a maximum length set with respectively "
                "`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
            )
        },
    )
    preprocessing_only: bool = field(
        default=False,
...@@ -444,16 +464,9 @@ class DataTrainingArguments:
    )
    add_audio_samples_to_wandb: bool = field(
        default=False,
        metadata={"help": "If set and if `wandb` in args.report_to, will add generated audio samples to wandb logs."},
    )
    id_column_name: str = field(default=None, metadata={"help": "id column name."})
    wandb_project: str = field(
        default="parler-speech",
        metadata={"help": "The name of the wandb project."},
...@@ -462,23 +475,15 @@ class DataTrainingArguments:
        default=None,
        metadata={
            "help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
        },
    )
    temporary_save_to_disk: str = field(default=None, metadata={"help": "Temporarily save audio labels here."})
    pad_to_multiple_of: Optional[int] = field(
        default=2,
        metadata={"help": ("Pad to multiple of for tokenizers.")},
    )
@dataclass
class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
    dtype: Optional[str] = field(
...@@ -492,17 +497,14 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
    )
    audio_encode_per_device_eval_batch_size: int = field(
        default=8,
        metadata={"help": ("TODO")},
    )


@dataclass
class DataCollatorEncodecWithPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
    to `max_length` if `max_length` is set and `padding=max_length`.
    """
...@@ -512,7 +514,6 @@ class DataCollatorEncodecWithPadding:
    max_length: Optional[int] = None
    padding: Optional[str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
...@@ -523,7 +524,7 @@ class DataCollatorEncodecWithPadding:
        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
        return batch
@dataclass
class DataCollatorParlerTTSWithPadding:
    """
...@@ -563,41 +564,55 @@ class DataCollatorParlerTTSWithPadding:
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        labels = [torch.tensor(feature["labels"]).transpose(0, 1) for feature in features]
        # (bsz, seq_len, num_codebooks)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        if self.audio_max_length is not None and self.padding == "max_length":
            labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))

        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        input_ids = self.description_tokenizer.pad(
            input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.description_max_length,
        )

        batch = {"labels": labels, **input_ids}

        if self.audio_max_length is not None and self.padding == "max_length":
            # if we do torch.compile, we need to also specify the attention_mask
            decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
            batch["decoder_attention_mask"] = decoder_attention_mask

        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(
            prompt_input_ids,
            return_tensors="pt",
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.prompt_max_length,
        )

        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

        if self.feature_extractor_input_name in features[0]:
            # TODO (YL): verify that it works - IMPORTANT -> probably not working
            input_values = [
                {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
            ]
            input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
            batch[self.feature_extractor_input_name] = input_values

        return batch
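Editor's aside (not from the commit): a minimal sketch of the padding behaviour the collator relies on, using only `torch.nn.utils.rnn.pad_sequence`; the toy shapes below are illustrative assumptions, not the real dataset layout.

import torch

# Two toy label sequences of different lengths, already transposed to (seq_len, num_codebooks).
num_codebooks = 2
labels = [torch.zeros(3, num_codebooks, dtype=torch.long), torch.ones(5, num_codebooks, dtype=torch.long)]

# Shorter sequences are right-padded with -100 so the loss ignores the padded positions.
padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
print(padded.shape)      # torch.Size([2, 5, 2])
print(padded[0, 3:, 0])  # tensor([-100, -100])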
def convert_dataset_str_to_list(
    dataset_names,
    dataset_config_names,
...@@ -660,7 +675,7 @@ def load_multiple_datasets(
    accelerator: Accelerator,
    dataset_names: Union[List, str],
    dataset_config_names: Union[List, str],
    metadata_dataset_names: Optional[str] = None,
    splits: Optional[Union[List, str]] = None,
    label_column_names: Optional[List] = None,
    stopping_strategy: Optional[str] = "first_exhausted",
...@@ -696,16 +711,16 @@ def load_multiple_datasets(
            **kwargs,
        )
        dataset_features = dataset.features.keys()

        if sampling_rate is not None and audio_column_name is not None:
            # resample target audio
            dataset = dataset.cast_column(audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate))

        metadata_dataset_name = dataset_dict["metadata_dataset_name"]
        if metadata_dataset_name is not None:
            logger.info(
                f'Merging {dataset_dict["name"]} - {dataset_dict["split"]} with {metadata_dataset_name} - {dataset_dict["split"]}'
            )
            metadata_dataset = load_dataset(
                metadata_dataset_name,
                dataset_dict["config"],
...@@ -713,7 +728,7 @@ def load_multiple_datasets(
                streaming=streaming,
                **kwargs,
            )

            # TODO(YL): I forgot to create unique ids for MLS english.
            # To iterate faster, I bypass the original id check and do another one. - Done once because assuming it won't change next time
            # if dataset_dict["name"] == "parler-tts/mls_eng_10k":
...@@ -728,32 +743,44 @@ def load_multiple_datasets(
                raise ValueError(
                    f"id_column_name={id_column_name} but has not been found in the dataset columns"
                    f"- one of {', '.join(list(dataset.column_names))}."
                )
            if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
                raise ValueError(
                    f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
                    f"- one of {', '.join(list(metadata_dataset.column_names))}."
                )
            elif id_column_name is not None:
                metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")

            metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))

            if prompt_column_name is not None:
                # We might have applied some transformations to the prompts (e.g. punctuation restoration)
                # so we make sure to remove it from the original dataset
                if prompt_column_name in dataset.column_names:
logger.info(f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']") logger.info(
f"REMOVE {prompt_column_name} from dataset {dataset_dict['name']} - dataset_dict['split']"
)
dataset.remove_columns(prompt_column_name) dataset.remove_columns(prompt_column_name)
            metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
            metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)

            dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

            if id_column_name is not None and dataset_dict["name"] != "parler-tts/mls_eng_10k":
                if (
                    len(
                        dataset.filter(
                            lambda id1, id2: id1 != id2,
                            input_columns=[id_column_name, f"metadata_{id_column_name}"],
                        )
                    )
                    != 0
                ):
                    raise ValueError(
                        f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}"
                    )

            dataset_features = dataset.features.keys()
...@@ -778,8 +805,7 @@ def load_multiple_datasets(
    return interleaved_dataset
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
...@@ -796,16 +822,22 @@ def main():
    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_parler_tts", model_args, data_args)

    if training_args.dtype == "float16":
        mixed_precision = "fp16"
    elif training_args.dtype == "bfloat16":
        mixed_precision = "bf16"
    else:
        mixed_precision = "no"

    if data_args.pad_to_max_length and (
        data_args.max_duration_in_seconds is None
        or data_args.max_prompt_token_length is None
        or data_args.max_description_token_length is None
    ):
        raise ValueError(
            "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
        )

    padding = "max_length" if data_args.pad_to_max_length else "longest"
...@@ -813,8 +845,8 @@ def main():
    kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
    if training_args.torch_compile:
        # TODO(YL): add more compile modes?
        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))  # reduce-overhead

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
...@@ -822,26 +854,28 @@ def main():
        project_dir=training_args.output_dir,
        kwargs_handlers=kwargs_handlers,
    )

    accelerator.init_trackers(
        project_name=data_args.wandb_project,
        config={
            "learning_rate": training_args.learning_rate,
            "model_name_or_path": model_args.model_name_or_path,
            "num_train_epochs": training_args.num_train_epochs,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
            "mixed_precision": mixed_precision,
            "lr_scheduler_type": training_args.lr_scheduler_type,
            "warmup_steps": training_args.warmup_steps,
            "freeze_text_encoder": model_args.freeze_text_encoder,
            "max_duration_in_seconds": data_args.max_duration_in_seconds,
            "weight_decay": training_args.weight_decay,
            "adam_beta1": training_args.adam_beta1,
            "adam_beta2": training_args.adam_beta2,
            "temperature": model_args.temperature,
        },
    )
    # Detect the last checkpoint and, if present, resume from it
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
...@@ -856,7 +890,6 @@ def main():
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
...@@ -880,17 +913,16 @@ def main():
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)
    num_workers = data_args.preprocessing_num_workers

    # 1. First, let's instantiate the feature extractor, tokenizers and model
    # Note for distributed training, the .from_pretrained methods guarantee that only
    # one local process can concurrently download model & vocab.

    # load feature extractor
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name or model_args.model_name_or_path,
...@@ -899,7 +931,7 @@ def main():
        trust_remote_code=data_args.trust_remote_code,
    )
    sampling_rate = feature_extractor.sampling_rate

    # load prompt tokenizer
    prompt_tokenizer = AutoTokenizer.from_pretrained(
        model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
...@@ -907,9 +939,9 @@ def main():
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",  # prompt has to be padded on the left because it's prepended to the codebooks' hidden states
    )

    # load description tokenizer
    description_tokenizer = AutoTokenizer.from_pretrained(
        model_args.description_tokenizer_name or model_args.model_name_or_path,
...@@ -918,31 +950,33 @@ def main():
        trust_remote_code=data_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
    )

    if model_args.use_fast_tokenizer:
        logger.warning(
            "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
        )
        prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
        description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
    # 2. Now, let's load the dataset
    if data_args.save_to_disk is not None:
        os.makedirs(data_args.save_to_disk, exist_ok=True)

    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
    if dataset_was_precomputed:
        vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
    else:
        raw_datasets = DatasetDict()

        columns_to_keep = {
            "target_audio_column_name": data_args.target_audio_column_name,
            "prompt_column_name": data_args.prompt_column_name,
        }
        if data_args.description_column_name is not None:
            columns_to_keep["description_column_name"] = data_args.description_column_name

        if training_args.do_train:
            raw_datasets["train"] = load_multiple_datasets(
                accelerator,
...@@ -961,14 +995,14 @@ def main():
                sampling_rate=sampling_rate,
                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
            )

            for key in columns_to_keep:
                if columns_to_keep[key] not in raw_datasets["train"].column_names:
                    raise ValueError(
                        f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
                        f" Make sure to set `--{key}` to the correct audio column - one of"
                        f" {', '.join(raw_datasets['train'].column_names)}."
                    )

            if data_args.max_train_samples is not None:
                raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
...@@ -977,7 +1011,9 @@ def main():
            raw_datasets["eval"] = load_multiple_datasets(
                accelerator,
                data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
                data_args.eval_dataset_config_name
                if data_args.eval_dataset_config_name
                else data_args.train_dataset_config_name,
                metadata_dataset_names=data_args.eval_metadata_dataset_name,
                splits=data_args.eval_split_name,
                cache_dir=model_args.cache_dir,
...@@ -991,8 +1027,9 @@ def main():
            )

            if data_args.max_eval_samples is not None:
                raw_datasets["eval"] = (
                    raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
                )
    # 3. Next, let's load the config.
    # TODO(YL): add the option to create the config from scratch
...@@ -1002,14 +1039,20 @@ def main():
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
    )

    # update pad token id and decoder_start_token_id
    # TODO(YL): verify if this makes sense, maybe should do it for model.decoder
    config.update(
        {
            "pad_token_id": model_args.pad_token_id
            if model_args.pad_token_id is not None
            else model.config.pad_token_id,
            "decoder_start_token_id": model_args.decoder_start_token_id
            if model_args.decoder_start_token_id is not None
            else model.config.decoder_start_token_id,
        }
    )

    # create model + TODO(YL): not from_pretrained probably
    model = ParlerTTSForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
...@@ -1018,16 +1061,16 @@ def main():
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
    )

    # enable gradient checkpointing if necessary
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`

    # derive max & min input length for sample rate & max duration
    sampling_rate = feature_extractor.sampling_rate
    max_target_length = data_args.max_duration_in_seconds * sampling_rate
...@@ -1042,18 +1085,18 @@ def main():
    max_length = model.generation_config.max_length
    num_codebooks = model.decoder.config.num_codebooks
    bandwidth = model_args.bandwidth

    # Freeze Encoders
    model.freeze_encoders(model_args.freeze_text_encoder)

    # TODO: remove when releasing
    # Test all gather - used for warmup and avoiding timeout
    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
    gathered_tensor = accelerator.gather(test_tensor)
    print("gathered_tensor", gathered_tensor)
    accelerator.wait_for_everyone()
    if not dataset_was_precomputed:
        # Filter on text length
        if description_column_name is not None and data_args.max_text_length is not None:
            with accelerator.main_process_first():
...@@ -1068,13 +1111,13 @@ def main():

        # We need to tokenize the texts.
        def pass_through_processors(description, prompt):
            batch = {}

            batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
            # TODO: add possibility to train without description column
            batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]

            return batch

        with accelerator.main_process_first():
            # this is a trick to avoid rewriting the entire audio column, which takes ages
            vectorized_datasets = raw_datasets.map(
...@@ -1087,13 +1130,13 @@ def main():

        # We use Accelerate to perform distributed inference
        # T5 doesn't support fp16
        autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))

        # Now we encode the audio labels with encodec.
        ####### B. Encode audio
        logger.info("*** Encode target audio with encodec ***")

        # no need to prepare audio_decoder because used for inference without mixed precision
        # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
        if training_args.torch_compile:
...@@ -1101,7 +1144,13 @@ def main():
        else:
            audio_decoder = model.audio_encoder

        encoder_data_collator = DataCollatorEncodecWithPadding(
            feature_extractor,
            audio_column_name=target_audio_column_name,
            feature_extractor_input_name=feature_extractor_input_name,
            max_length=max_target_length,
            padding=padding,
        )

        def apply_audio_decoder(batch):
            len_audio = batch.pop("len_audio")
...@@ -1111,8 +1160,8 @@ def main():
            output = {}
            output["len_audio"] = len_audio
            # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
            output["labels"] = labels.squeeze(0).transpose(1, 2)
            output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
            return output
        for split in vectorized_datasets:
...@@ -1123,82 +1172,83 @@ def main():
                num_workers=training_args.dataloader_num_workers,
                pin_memory=True,
            )
            data_loader = accelerator.prepare(data_loader)

            all_generated_labels = []
            all_lens = []
            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
                generate_labels = apply_audio_decoder(batch)
                generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
                generate_labels = accelerator.gather_for_metrics(generate_labels)

                if accelerator.is_main_process:
                    lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
                    rat = generate_labels["ratio"].cpu().squeeze()
                    lens = generate_labels["len_audio"].cpu().squeeze()
                    lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]

                    all_generated_labels.extend(lab)
                    all_lens.extend(lens)

            # (1, codebooks, seq_len) where seq_len=1
            bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id

            if accelerator.is_main_process:
                tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
                tmp_labels.save_to_disk(
                    os.path.join(data_args.temporary_save_to_disk, split),
                    num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
                )
            accelerator.wait_for_everyone()
            del all_generated_labels

            tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
            with accelerator.main_process_first():
                vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
            def postprocess_dataset(labels):
                # (1, codebooks, seq_len)
                labels = torch.tensor(labels).unsqueeze(0)
                # add bos
                labels = torch.cat([bos_labels, labels], dim=-1)

                labels, delay_pattern_mask = build_delay_pattern_mask(
                    labels,
                    bos_token_id=audio_encoder_bos_token_id,
                    pad_token_id=audio_encoder_eos_token_id,
                    max_length=labels.shape[-1] + num_codebooks,
                    num_codebooks=num_codebooks,
                )

                # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
                # to take care of EOS
                # we want labels to look like this:
                #  - [B, a, b, E, E, E, E]
                #  - [B, B, c, d, E, E, E]
                #  - [B, B, B, e, f, E, E]
                #  - [B, B, B, B, g, h, E]
                labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)

                # the first timestamp is associated to a row full of BOS, let's get rid of it
                # we also remove the last timestamps (full of PAD)
                output = {"labels": labels[:, 1:]}
                return output

            # TODO(YL): done multiple times, how to deal with it.
            with accelerator.main_process_first():
                vectorized_datasets[split] = vectorized_datasets[split].map(
                    postprocess_dataset,
                    num_proc=data_args.preprocessing_num_workers,  # this one is resource consuming with many processes.
                    input_columns=["labels"],
                    desc="Postprocessing labeling",
                )
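Editor's aside (not from the commit): a tiny, self-contained sketch of the delay pattern the comments above describe, using made-up token ids; it mimics the layout rather than calling `build_delay_pattern_mask` itself.

import torch

BOS, EOS = 1024, 1025  # assumed bos/eos ids, purely illustrative
codes = torch.tensor([[10, 11], [20, 21], [30, 31], [40, 41]])  # (num_codebooks=4, seq_len=2)
num_codebooks, seq_len = codes.shape
width = seq_len + num_codebooks + 1  # one BOS column + staggered codes + trailing EOS

# codebook k is shifted right by k+1 positions: BOS before, EOS after
delayed = torch.full((num_codebooks, width), EOS)
for k in range(num_codebooks):
    delayed[k, : k + 1] = BOS
    delayed[k, k + 1 : k + 1 + seq_len] = codes[k]

print(delayed)
# tensor([[1024,   10,   11, 1025, 1025, 1025, 1025],
#         [1024, 1024,   20,   21, 1025, 1025, 1025],
#         [1024, 1024, 1024,   30,   31, 1025, 1025],
#         [1024, 1024, 1024, 1024,   40,   41, 1025]])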
        accelerator.free_memory()
        del generate_labels, all_lens

        with accelerator.main_process_first():
            # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
            # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
            # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
            def is_audio_in_length_range(length):
...@@ -1210,7 +1260,7 @@ def main():
                num_proc=num_workers,
                input_columns=["target_length"],
            )

        if description_column_name is not None and data_args.max_description_token_length is not None:
            with accelerator.main_process_first():
                # filter description that is shorter than max_text_length
...@@ -1228,22 +1278,24 @@ def main():
                num_proc=num_workers,
                input_columns=["prompt_input_ids"],
            )

    if data_args.save_to_disk is not None and not dataset_was_precomputed:
        if accelerator.is_main_process:
            vectorized_datasets.save_to_disk(
                data_args.save_to_disk,
                num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
            )
        logger.info(f"Dataset saved at {data_args.save_to_disk}")

    audio_max_length = None
    if training_args.torch_compile:
        audio_max_length = max(vectorized_datasets["train"]["target_length"])
        with accelerator.main_process_first():
            max_sample = vectorized_datasets["train"].filter(
                lambda x: x == audio_max_length,
                num_proc=num_workers,
                input_columns=["target_length"],
            )
        audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
    # for large datasets it is advised to run the preprocessing on a
...@@ -1252,44 +1304,50 @@ def main():
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only and data_args.save_to_disk is None:
        raise ValueError(
            "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicate where to save the dataset locally."
        )
    elif data_args.preprocessing_only:
        logger.info(f"Data preprocessing finished. Files saved at {data_args.save_to_disk}")
        return

    # 6. Next, we can prepare the training.
    # Let's use CLAP similarity and WER as our evaluation metrics,
    # Define evaluation metrics during training, *i.e.* CLAP similarity TODO: allow using another CLAP
    clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
    clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
    metric = evaluate.load("wer")
    def clap_similarity(texts, audios, device):
        clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
        clap.to(device)
        with torch.no_grad():
            text_features = clap.get_text_features(
                clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
            )
            audio_features = clap.get_audio_features(clap_inputs["input_features"])

            cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)

        clap.to("cpu")
        clap_inputs.to("cpu")
        return cosine_sim.mean().to("cpu")

    def wer(prompts, audios, device):
        asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
        transcriptions = asr_pipeline(
            [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
            batch_size=int(training_args.per_device_eval_batch_size),
        )

        word_error = 100 * metric.compute(
            predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
        )

        return word_error, [t["text"] for t in transcriptions]
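Editor's aside (not from the commit): a quick illustration of how the `evaluate` WER metric behaves on toy strings, so the `100 *` scaling above is easier to read.

import evaluate

metric = evaluate.load("wer")

# One substitution out of four reference words -> WER of 0.25, reported as 25.0 after the * 100 scaling.
predictions = ["the cat sat up"]
references = ["the cat sat down"]
print(100 * metric.compute(predictions=predictions, references=references))  # 25.0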
    eval_methods = {"clap": clap_similarity, "wer": wer}

    def compute_metrics(audios, descriptions, prompts, device="cpu"):
...@@ -1297,22 +1355,19 @@ def main():
        texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
        audios = [a.cpu().numpy() for a in audios]
        results = {"clap": eval_methods["clap"](texts, audios, device)}
        word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
        results["wer"] = word_error

        return results, texts, prompts, audios, transcriptions
    # Define Training Schedule
    # Store some constants
    per_device_train_batch_size = int(training_args.per_device_train_batch_size)
    train_batch_size = per_device_train_batch_size * accelerator.num_processes
    gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)

    if training_args.max_steps < 0:
        num_epochs = int(training_args.num_train_epochs)
        steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
...@@ -1325,16 +1380,14 @@ def main():
        steps_per_epoch = total_train_steps

    if training_args.eval_steps is None:
        logger.info(f"eval_steps is not set, evaluating at the end of each epoch")
        eval_steps = steps_per_epoch
    else:
        eval_steps = training_args.eval_steps
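Editor's aside (illustrative numbers, not from the diff): how the schedule constants above combine.

# Assume 4 processes, per-device batch size 8, gradient accumulation 2 and 100_000 training samples.
num_processes = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
num_train_samples = 100_000

train_batch_size = per_device_train_batch_size * num_processes  # 32 samples per forward/backward step
steps_per_epoch = num_train_samples // (train_batch_size * gradient_accumulation_steps)  # 1562 optimizer steps
print(train_batch_size, steps_per_epoch)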
    # T5 doesn't support fp16
    autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))

    # Define optimizer, LR scheduler, collator
    optimizer = torch.optim.AdamW(
        params=model.parameters(),
...@@ -1354,14 +1407,20 @@ def main():
    # Instantiate custom data collator
    data_collator = DataCollatorParlerTTSWithPadding(
        audio_feature_extractor=feature_extractor,
        feature_extractor_input_name=feature_extractor_input_name,
        prompt_tokenizer=prompt_tokenizer,
        description_tokenizer=description_tokenizer,
        pad_to_multiple_of=data_args.pad_to_multiple_of,
        padding=padding,
        prompt_max_length=data_args.max_prompt_token_length,
        description_max_length=data_args.max_description_token_length,
        audio_max_length=audio_max_length,
    )

    # Prepare everything with accelerate
    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
    logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
@@ -1386,8 +1445,7 @@ def main():
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint

if accelerator.is_main_process:
    if training_args.push_to_hub:
        # Retrieve or infer repo_name
@@ -1405,23 +1463,27 @@ def main():
    elif training_args.output_dir is not None:
        os.makedirs(training_args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()

# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with accelerator.main_process_first():
    # only the main process saves them
    if accelerator.is_main_process:
        # save feature extractor, tokenizer and config
        if (
            model_args.prompt_tokenizer_name is None
            and model_args.description_tokenizer_name is None
            or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name)
        ):
            prompt_tokenizer.save_pretrained(training_args.output_dir)
        else:
            logger.warning(
                f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
            )
            prompt_tokenizer.save_pretrained(training_args.output_dir)
        feature_extractor.save_pretrained(training_args.output_dir)
        config.save_pretrained(training_args.output_dir)

if checkpoint is not None:
    accelerator.load_state(checkpoint)
@@ -1439,7 +1501,7 @@ def main():
    for epoch in range(0, epochs_trained):
        vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)

    if training_args.max_steps < 0:
        # we know exactly the number of steps per epoch, so can skip through the required number of batches
        resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
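        # NOTE: resume_step counts dataloader batches rather than optimizer updates,
        # hence the multiplication by gradient_accumulation_steps.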
@@ -1451,13 +1513,13 @@ def main():
        vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
else:
    resume_step = None
gen_kwargs = {
    "do_sample": model_args.do_sample,
    "temperature": model_args.temperature,
    "max_length": model_args.max_length,
}
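# NOTE: these sampling kwargs are presumably forwarded to the `generate` call made
# during the evaluation/generation step further below.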
# Define gradient update step fn
def train_step(
    batch,
@@ -1465,26 +1527,34 @@ def main():
    autocast_kwargs,
):
    model.train()

    if mixed_precision == "fp16":
        # fp16 doesn't work with T5-like models
        with accelerator.autocast(autocast_handler=autocast_kwargs):
            if training_args.parallel_mode.value != "distributed":
                encoder_outputs = model.text_encoder(
                    input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
                )
            else:
                encoder_outputs = model.module.text_encoder(
                    input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
                )
            batch["encoder_outputs"] = encoder_outputs
    outputs = model(**batch)
    # CE (data) loss
    ce_loss = outputs.loss
    # TODO: add CE per codebook
    metrics = {"loss": ce_loss}
    return ce_loss, metrics
# Define eval fn
def eval_step(
    batch,
    accelerator,
    autocast_kwargs,
):
    eval_model = model if not training_args.torch_compile else model._orig_mod
    eval_model.eval()
@@ -1493,9 +1563,13 @@ def main():
    with accelerator.autocast(autocast_handler=autocast_kwargs):
        with torch.no_grad():
            if training_args.parallel_mode.value != "distributed" or training_args.torch_compile:
                encoder_outputs = eval_model.text_encoder(
                    input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
                )
            else:
                encoder_outputs = eval_model.module.text_encoder(
                    input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
                )
            batch["encoder_outputs"] = encoder_outputs

    with torch.no_grad():
@@ -1507,7 +1581,7 @@ def main():

def generate_step(batch):
    batch.pop("decoder_attention_mask", None)
    eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
    if training_args.torch_compile:
        eval_model = model._orig_mod
@@ -1518,7 +1592,7 @@ def main():
for epoch in range(epochs_trained, num_epochs):
    vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
    # TODO(YL): add args
    sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
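    # NOTE: grouping samples of similar target length into the same batch keeps the
    # amount of padding per batch small.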
    train_dataloader = DataLoader(
        vectorized_datasets["train"],
        collate_fn=data_collator,
@@ -1546,7 +1620,6 @@ def main():
            lr_scheduler.step()
            optimizer.zero_grad()

        # Check if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            steps_trained_progress_bar.update(1)
@@ -1597,7 +1670,7 @@ def main():
        eval_descriptions = []
        eval_prompts = []
        eval_start = time.time()

        # release training input batch
        batch = release_memory(batch)
@@ -1634,17 +1707,21 @@ def main():
        validation_dataloader = accelerator.prepare(validation_dataloader)

        # generation
        for batch in tqdm(
            validation_dataloader,
            desc="Evaluating - Generation ...",
            position=2,
            disable=not accelerator.is_local_main_process,
        ):
            generated_audios = generate_step(batch)
            # Gather all predictions and targets
            # TODO: also add prompt ids
            # TODO: better gather
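            # NOTE: sequences are first padded to a common length across processes so
            # that `gather_for_metrics` can concatenate them without shape mismatches.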
            generated_audios, input_ids, prompts = accelerator.pad_across_processes(
                (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
            )
            generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
                (generated_audios, input_ids, prompts)
            )
            eval_preds.extend(generated_audios.to("cpu"))
            eval_descriptions.extend(input_ids.to("cpu"))
            eval_prompts.extend(prompts.to("cpu"))
@@ -1652,7 +1729,8 @@ def main():
        eval_time = time.time() - eval_start
        # normalize eval metrics
        eval_metrics = {
            key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics]))
            for key in eval_metrics[0]
        }

        # compute metrics
@@ -1689,7 +1767,7 @@ def main():
            epoch=epoch,
            prefix="eval",
        )

        # release eval batch and reset metrics
        eval_metrics = []
        eval_preds = []
@@ -1697,7 +1775,6 @@ def main():
        eval_prompts = []
        batch = release_memory(batch)

        # flush the train metrics
        train_start = time.time()
@@ -1710,9 +1787,8 @@ def main():
            break

    accelerator.end_training()


if __name__ == "__main__":
    set_start_method("spawn")
    main()