Commit 11fcc066 authored by Yoach Lacombe

working training + generation

parent 997bf5e6
@@ -36,14 +36,14 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2048,
+    "pad_token_id": 2049,
     "decoder_start_token_id": 2048,
     "do_train": true,
-    "num_train_epochs": 20,
+    "num_train_epochs": 1,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": true,
-    "per_device_train_batch_size": 16,
+    "per_device_train_batch_size": 8,
     "learning_rate": 1e-6,
     "adam_beta1": 0.9,
     "adam_beta2": 0.95,
@@ -56,7 +56,7 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "epoch",
-    "per_device_eval_batch_size": 16,
+    "per_device_eval_batch_size": 8,
     "generation_max_length": 400,
     "fp16": true,
...
@@ -35,7 +35,8 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2048
+model.generation_config.pad_token_id = 2049
+model.generation_config.eos_token_id = 2049

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
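
For orientation, here is the token-id convention this commit converges on, written out as a sketch. The constant names are illustrative only; the repository configures these values through `generation_config` and the decoder config rather than named constants.

```python
# Illustrative constants, inferred from the values set in this commit.
ENCODEC_CODEBOOK_SIZE = 2048  # valid audio code ids span 0..2047
BOS_TOKEN_ID = 2048           # decoder_start_token_id, one past the last code id
EOS_TOKEN_ID = 2049           # end-of-sequence id introduced by this commit
PAD_TOKEN_ID = 2049           # padding reuses the EOS id
```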
...
@@ -77,16 +77,10 @@ def list_field(default=None, metadata=None):
 class StableSpeechTrainer(Seq2SeqTrainer):
     def _pad_tensors_to_max_len(self, tensor, max_length):
-        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-            # If PAD token is not defined at least EOS token has to be defined
-            pad_token_id = (
-                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            )
+        if self.model.config.pad_token_id is not None:
+            pad_token_id = self.model.config.pad_token_id
         else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+            raise ValueError("pad_token_id must be set in the model configuration in order to pad tensors")

         padded_tensor = pad_token_id * torch.ones(
             (tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
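
A minimal, standalone sketch of what `_pad_tensors_to_max_len` now does. The copy of the original values back into the padded buffer is not shown in the hunk; this assumes the usual `Seq2SeqTrainer` behaviour, and the `(batch, seq_len, num_codebooks)` layout is inferred from the `torch.ones` shape above.

```python
import torch

def pad_tensors_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    # fill a (batch, max_length, num_codebooks) buffer with the pad id...
    padded = pad_token_id * torch.ones(
        (tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
    )
    # ...then copy the real values into the left part of the buffer
    padded[:, : tensor.shape[1]] = tensor
    return padded

labels = torch.zeros((2, 5, 4), dtype=torch.long)  # batch=2, seq_len=5, codebooks=4
padded = pad_tensors_to_max_len(labels, max_length=8, pad_token_id=2049)
assert padded.shape == (2, 8, 4) and (padded[:, 5:] == 2049).all()
```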
@@ -387,6 +381,7 @@ class DataCollatorStableSpeechWithPadding:
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+        # TODO: check it's been padded on the left

         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
@@ -676,6 +671,7 @@ def main():
     )

     # update pad token id and decoder_start_token_id
+    # TODO: verify if this makes sense, maybe should do it for model.decoder
     config.update({
         "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else model.config.pad_token_id,
         "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
...@@ -700,6 +696,7 @@ def main(): ...@@ -700,6 +696,7 @@ def main():
token=data_args.token, token=data_args.token,
trust_remote_code=data_args.trust_remote_code, trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer, use_fast=model_args.use_fast_tokenizer,
padding_side="left", # prompt has to be padded on the left bc it's preprend to codebooks hidden states
) )
# load description tokenizer # load description tokenizer
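
Left padding matters here because the prompt embeddings are prepended to the codebook hidden states: right padding would insert pad embeddings between the prompt and the audio tokens. A quick illustration with an off-the-shelf tokenizer (`gpt2` is just a stand-in for the actual prompt tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

batch = tokenizer(["Hey", "A much longer prompt"], padding=True, return_tensors="pt")
# With padding_side="left", pad tokens sit at the START of the shorter row, so the
# real prompt tokens stay flush against whatever is concatenated after them:
print(batch["attention_mask"])  # e.g. tensor([[0, 0, 0, 1], [1, 1, 1, 1]])
```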
@@ -740,6 +737,10 @@ def main():
     description_column_name = data_args.description_column_name
     prompt_column_name = data_args.prompt_column_name
     feature_extractor_input_name = feature_extractor.model_input_names[0]
+    audio_encoder_eos_token_id = config.decoder.pad_token_id
+    audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
+    max_length = model.generation_config.max_length
+    num_codebooks = model.decoder.config.num_codebooks

     # resample target audio
     raw_datasets = raw_datasets.cast_column(
@@ -794,7 +795,6 @@ def main():
     # no need to prepare audio_decoder because used for inference without mixed precision
     # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
-    # TODO: load another model
     audio_decoder = model.audio_encoder

     encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
@@ -832,18 +832,28 @@ def main():
             all_ratios.extend(generate_labels["ratio"].cpu())
             all_lens.extend(generate_labels["len_audio"].cpu())

+        # (1, codebooks, seq_len) where seq_len=1
+        eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
+
         def postprocess_dataset(sample, idx):
-            # (1, seq_len, codebooks, bsz)
+            # (1, codebooks, seq_len)
             labels = all_generated_labels[idx].transpose(0, 1).unsqueeze(0)
+            len_ = int(all_ratios[idx] * all_lens[idx])
+            labels = labels[:, :, :len_]
+
+            # add an eos token column to every codebook
+            labels = torch.cat([labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
+
-            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
-                model.generation_config.decoder_start_token_id,
-                model.generation_config.max_length + model.decoder.config.num_codebooks)
+            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(
+                labels,
+                audio_encoder_bos_token_id,
+                audio_encoder_eos_token_id,
+                max_length + num_codebooks,
+            )
             labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
-            len_ = int(all_ratios[idx] * all_lens[idx])

             # the first timestep corresponds to a row full of BOS tokens, let's get rid of it
-            sample["labels"] = labels[:, 1:len_]
+            sample["labels"] = labels[:, 1:]
             return sample

         # TODO: done multiple times, how to deal with it.
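
A toy reproduction of the label post-processing above, with made-up values, to make the shapes concrete: trim the generated codes to the true audio length, then append one EOS column across all codebooks before the delay pattern is built.

```python
import torch

num_codebooks, eos_token_id = 4, 2049
labels = torch.randint(0, 2048, (1, num_codebooks, 10))  # (1, codebooks, seq_len)

len_ = 7                                  # stands in for int(ratio * len_audio)
labels = labels[:, :, :len_]              # drop the padded tail of the generation

eos_labels = torch.ones((1, num_codebooks, 1), dtype=labels.dtype) * eos_token_id
labels = torch.cat([labels, eos_labels], dim=-1)

assert labels.shape == (1, num_codebooks, 8)
assert (labels[..., -1] == eos_token_id).all()
```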
@@ -956,7 +966,7 @@ def main():
         """Custom WandbCallback to log model predictions during training.
         """
-        def __init__(self, trainer, val_dataset,
+        def __init__(self, trainer, val_dataset, description_tokenizer,  # TODO: add
                      num_samples=8):
             """Initializes the WandbPredictionProgressCallback instance.
@@ -969,6 +979,7 @@ def main():
             """
             super().__init__()
             self.trainer = trainer
+            self.description_tokenizer = description_tokenizer
             self.sample_dataset = val_dataset.select(range(num_samples))

         def on_train_end(self, args, state, control, **kwargs):
@@ -992,6 +1003,7 @@ def main():
         progress_callback = WandbPredictionProgressCallback(
             trainer=trainer,
             val_dataset=vectorized_datasets["eval"],
+            description_tokenizer=description_tokenizer,
             num_samples=8,  # TODO: add to args
         )
...
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2048):
+        vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `input_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2048,
+        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos token)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2048,
+        pad_token_id=2049,
         bos_token_id=2048,
-        eos_token_id=None,
+        eos_token_id=2049,
         tie_word_embeddings=False,
         **kwargs,
     ):
...
@@ -731,7 +731,7 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if prompt_hidden_states is not None:
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
-        # TODO: verify if prompt attention mask is required
+        # TODO: verify if prompt attention mask is required and has to be
         # As it is, the masked ids from the prompt will still count in the position embeddings
         if prompt_attention_mask is not None and attention_mask is not None:
             attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
@@ -754,6 +754,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         )

         # embed positions
+        # TODO: as it is, the masked ids from the prompt will still count in the position embeddings
+        # maybe we should modify the position embeddings
         positions = self.embed_positions(inputs_embeds, past_key_values_length)

         hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
@@ -1064,7 +1066,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if delay_pattern_mask is None:
             input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
                 input_ids,
-                pad_token_id=self.generation_config.pad_token_id,
+                bos_token_id=self.generation_config.decoder_start_token_id,
+                eos_token_id=self.generation_config.eos_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -1108,22 +1111,22 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         }

     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
         """Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8; we have the delayed pattern mask of shape `(codebooks,
         seq_len)`:
-        - [P, -1, -1, -1, -1, P, P, P]
-        - [P, P, -1, -1, -1, -1, P, P]
-        - [P, P, P, -1, -1, -1, -1, P]
-        - [P, P, P, P, -1, -1, -1, -1]
-        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+        - [B, -1, -1, -1, -1, E, E, E]
+        - [B, B, -1, -1, -1, -1, E, E]
+        - [B, B, B, -1, -1, -1, -1, E]
+        - [B, B, B, B, -1, -1, -1, -1]
+        where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
         a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
         mask is set to the value in the prompt:
-        - [P, a, b, -1, -1, P, P, P]
-        - [P, P, c, d, -1, -1, P, P]
-        - [P, P, P, e, f, -1, -1, P]
-        - [P, P, P, P, g, h, -1, -1]
+        - [B, a, b, -1, -1, E, E, E]
+        - [B, B, c, d, -1, -1, E, E]
+        - [B, B, B, e, f, -1, -1, E]
+        - [B, B, B, B, g, h, -1, -1]
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
         """
@@ -1147,14 +1150,16 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # construct a pattern mask that indicates the positions of padding tokens for each codebook
         # first fill the upper triangular part (the EOS padding)
-        delay_pattern = torch.triu(
+        eos_delay_pattern = torch.triu(
             torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
         )
         # then fill the lower triangular part (the BOS padding)
-        delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
-        mask = ~delay_pattern.to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+        bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+
+        bos_mask = ~bos_delay_pattern.to(input_ids.device)
+        eos_mask = ~eos_delay_pattern.to(input_ids.device)
+        mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id

         # find the first position to start generating - this is the first place we have the -1 token
         # and will always be in the first codebook (since it has no codebook offset)
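
To see that the two triangular masks reproduce the diagram in the docstring, here is a self-contained check for `num_codebooks=4` and `max_length=8`. Only the mask construction is reproduced; the shifting of `input_ids` is omitted.

```python
import torch

num_codebooks, max_length = 4, 8
bos_token_id, eos_token_id = 2048, 2049

eos_delay_pattern = torch.triu(
    torch.ones((num_codebooks, max_length), dtype=torch.bool),
    diagonal=max_length - num_codebooks + 1,
)
bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))

pattern = torch.full((num_codebooks, max_length), -1)
pattern[bos_delay_pattern] = bos_token_id  # lower triangle: one extra BOS per codebook
pattern[eos_delay_pattern] = eos_token_id  # upper-right triangle: the EOS tail
print(pattern)
# tensor([[2048,   -1,   -1,   -1,   -1, 2049, 2049, 2049],
#         [2048, 2048,   -1,   -1,   -1,   -1, 2049, 2049],
#         [2048, 2048, 2048,   -1,   -1,   -1,   -1, 2049],
#         [2048, 2048, 2048, 2048,   -1,   -1,   -1,   -1]])
```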
@@ -1334,7 +1339,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
-            pad_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.decoder_start_token_id,
+            eos_token_id=generation_config.eos_token_id,
             max_length=generation_config.max_length,
         )
@@ -1436,9 +1442,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

-        # revert the delay pattern mask by filtering the pad token id
-        output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
-            batch_size, self.num_codebooks, -1
-        )
+        # revert the delay pattern mask by filtering out the bos and eos token ids
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(batch_size, self.decoder.num_codebooks, -1)

         if generation_config.return_dict_in_generate:
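
The key change in this filter is that the kept positions are now decided by the delay pattern mask rather than by comparing `output_ids` against a single pad id. A toy check (2 codebooks, sequence length 4) shows why the `reshape` stays valid: the delay pattern removes the same number of positions from every codebook row.

```python
import torch

bos, eos = 2048, 2049
delay_pattern_mask = torch.tensor([[bos, 11, 12, eos],
                                   [bos, bos, 13, 14]])  # (codebooks=2, seq=4)
output_ids = torch.tensor([[bos, 21, 22, eos],
                           [bos, bos, 23, 24]])

keep = (delay_pattern_mask != bos) & (delay_pattern_mask != eos)
filtered = output_ids[keep].reshape(1, 2, -1)  # batch_size=1, num_codebooks=2
print(filtered)  # tensor([[[21, 22], [23, 24]]])
```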
@@ -1919,7 +1925,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if prompt_hidden_states is None:
             if prompt_input_ids is not None:
                 prompt_hidden_states = self.embed_prompts(prompt_input_ids)
-        # TODO: do we do something with prompt_attention_mask ? e.g multiply it to prompt_hidden_states?
+        # TODO: verify prompt_attention_mask

         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
@@ -1999,7 +2005,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_delay_pattern_mask is None:
             decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
-                self.generation_config.pad_token_id,
+                bos_token_id=self.generation_config.decoder_start_token_id,
+                eos_token_id=self.generation_config.eos_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -2428,7 +2435,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
-            pad_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.decoder_start_token_id,
+            eos_token_id=generation_config.eos_token_id,
             max_length=generation_config.max_length,
         )

         # stash the delay mask so that we don't have to recompute in each forward pass
@@ -2531,8 +2539,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

-        # revert the delay pattern mask by filtering the pad token id
-        output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+        # revert the delay pattern mask by filtering out the bos and eos token ids
+        output_ids = output_ids[
+            (model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
@@ -2543,10 +2551,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        output_values = self.audio_encoder.decode(
-            output_ids,
-            audio_scales=audio_scales,
-        ).audio_values
+        # decode in a single batch only if no sample still contains bos/eos ids
+        decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
+        if decode_in_batch.item():
+            output_values = self.audio_encoder.decode(
+                output_ids,
+                audio_scales=audio_scales,
+            ).audio_values
+        else:
+            output_values = []
+            for sample_id in range(batch_size):
+                sample = output_ids[:, sample_id]
+                sample_mask = ((sample == generation_config.bos_token_id) | (sample == generation_config.eos_token_id)).sum(dim=(0, 1)) == 0
+                sample = sample[:, :, sample_mask]
+                sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
+                output_values.append(sample.transpose(0, 2))
+            # TODO: we should keep track of output lengths as well. Not really straightforward tbh.
+            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1)

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
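
Because the per-sample path yields audio clips of different lengths, they are re-batched with `pad_sequence`, which pads along the first dimension; hence the transposes that move the time axis to the front and back again. A toy shape check of that round trip (the `(1, channels, audio_len)` layout for `audio_values` is an assumption inferred from the transposes above):

```python
import torch

# two decoded samples of different lengths, each shaped (1, channels, audio_len)
samples = [torch.randn(1, 1, 120), torch.randn(1, 1, 80)]

padded = torch.nn.utils.rnn.pad_sequence(
    [s.transpose(0, 2) for s in samples],  # -> (audio_len, channels, 1), time axis first
    batch_first=True,
    padding_value=0,
)                                          # -> (batch, max_len, channels, 1)
output_values = padded.transpose(1, 2).squeeze(-1)
print(output_values.shape)                 # torch.Size([2, 1, 120])
```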
...