Commit d0140745 authored by Yoach Lacombe

update code: fix accelerate, fix delay pattern mask, improve generation

parent ee12a812
@@ -36,8 +36,8 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2050,
-    "decoder_start_token_id": 2048,
+    "pad_token_id": 2048,
+    "decoder_start_token_id": 2049,
     "do_train": true,
     "num_train_epochs": 120,
...
@@ -24,8 +24,8 @@
     "description_column_name": "text_description",
     "prompt_column_name": "text",
-    "max_train_samples": 12,
-    "max_eval_samples": 12,
+    "max_train_samples": 4,
+    "max_eval_samples": 4,
     "max_duration_in_seconds": 30,
@@ -36,14 +36,14 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2050,
-    "decoder_start_token_id": 2048,
+    "pad_token_id": 2048,
+    "decoder_start_token_id": 2049,
     "do_train": true,
-    "num_train_epochs": 20,
+    "num_train_epochs": 120,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": false,
-    "per_device_train_batch_size": 3,
+    "per_device_train_batch_size": 2,
     "learning_rate": 1e-3,
     "adam_beta1": 0.9,
     "adam_beta2": 0.999,
@@ -60,10 +60,10 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "steps",
-    "eval_steps": 10,
-    "per_device_eval_batch_size": 3,
+    "eval_steps": 30,
+    "per_device_eval_batch_size": 2,
     "generation_max_length": 400,
-    "do_sample": true,
+    "do_sample": false,
     "logging_steps": 15,
...
@@ -34,9 +34,9 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048
 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
...
@@ -18,7 +18,7 @@ decoder_config = StableSpeechDecoderConfig(
 decoder = StableSpeechForCausalLM(decoder_config)
-decoder.save_pretrained("/raid/yoach/tmp/decoder/")
+decoder.save_pretrained("/home/yoach/dataspeech/artefacts/decoder/")
 t5 = AutoConfig.from_pretrained("t5-base")
@@ -26,18 +26,18 @@ t5 = AutoConfig.from_pretrained("t5-base")
 model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path="t5-base",
     audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz",
-    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/decoder/",
+    decoder_pretrained_model_name_or_path="/home/yoach/dataspeech/artefacts/decoder/",
     vocab_size = t5.vocab_size
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048
 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False # True
 model.generation_config.guidance_scale = 1 # 3.0
-model.save_pretrained("/raid/yoach/tmp/small-stable-speech-untrained/")
+model.save_pretrained("/home/yoach/dataspeech/artefacts/small-stable-speech-untrained/")
\ No newline at end of file
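As a sanity check, one might reload the checkpoint saved above. This is a hedged sketch only: the exact `generate` signature of `StableSpeechForConditionalGeneration` is not shown in this commit, so the kwargs below are assumptions based on its MusicGen lineage.

```python
import torch
from transformers import AutoTokenizer
from stable_speech import StableSpeechForConditionalGeneration

model = StableSpeechForConditionalGeneration.from_pretrained(
    "/home/yoach/dataspeech/artefacts/small-stable-speech-untrained/"
)
tokenizer = AutoTokenizer.from_pretrained("t5-base")
inputs = tokenizer("A calm female voice.", return_tensors="pt")

# greedy decoding with no guidance, matching the generation config set above
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256)  # assumed kwargs
```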
@@ -26,6 +26,8 @@ import shutil
 import warnings
 import math
 import time
+from multiprocess import set_start_method
 import evaluate
 from tqdm import tqdm
@@ -63,7 +65,7 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed
-from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
+from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
 if is_wandb_available():
     from wandb import Audio
@@ -516,15 +518,10 @@ class DataCollatorStableSpeechWithPadding:
         # (bsz, seq_len, num_codebooks)
         labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
-        delay_pattern_mask = [torch.tensor(feature["label_delay_pattern_mask"]).transpose(0,1) for feature in features]
-        # (bsz, seq_len, num_codebooks)
-        delay_pattern_mask = torch.nn.utils.rnn.pad_sequence(delay_pattern_mask,batch_first=True,padding_value=-100)
         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
         input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
-        batch= {"labels":labels, "label_delay_pattern_mask":delay_pattern_mask, **input_ids}
+        batch= {"labels":labels, **input_ids}
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
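A tiny demo of the padding step the collator keeps, with shapes per the collator's own comment; purely illustrative:

```python
import torch

# two samples of different lengths, each (seq_len, num_codebooks)
labels = [torch.zeros(3, 4, dtype=torch.long), torch.ones(5, 4, dtype=torch.long)]
padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
print(padded.shape)   # torch.Size([2, 5, 4]) -> (bsz, seq_len, num_codebooks)
print(padded[0, 3:])  # the shorter sample is filled with -100, which the loss ignores
```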
@@ -1014,23 +1011,30 @@ def main():
             len_ = int(all_ratios[idx] * all_lens[idx])
             labels = labels[:, :, :len_]
-            # TODO: remove, only for test
-            labels = labels[:, :, :(len_)%10+20]
-            # add bos and eos token column
-            labels = torch.cat([bos_labels,labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
+            labels = labels[:, :, :(len_)%10+20] # TODO: change
+            # add bos
+            labels = torch.cat([bos_labels, labels], dim=-1)
-            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
+            labels, delay_pattern_mask = build_delay_pattern_mask(labels,
                                             bos_token_id=audio_encoder_bos_token_id,
-                                            pad_token_id=audio_encoder_pad_token_id,
-                                            max_length=labels.shape[-1] + num_codebooks)
+                                            pad_token_id=audio_encoder_eos_token_id,
+                                            max_length=labels.shape[-1] + num_codebooks,
+                                            num_codebooks=num_codebooks)
-            labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
+            # the first ids of the delay pattern mask are precisely the labels; we use the rest of the mask
+            # to take care of EOS
+            # we want labels to look like this:
+            # - [B, a, b, E, E, E, E]
+            # - [B, B, c, d, E, E, E]
+            # - [B, B, B, e, f, E, E]
+            # - [B, B, B, B, g, h, E]
+            labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
             # the first timestamp is associated with a row full of BOS, let's get rid of it
-            sample["labels"] = labels[:, 1:]
-            sample["label_delay_pattern_mask"] = delay_pattern_mask[:, 1:]
+            # we also remove the last timestamps (full of PAD)
+            sample["labels"] = labels[:, 1:].cpu()
             return sample
         # TODO: done multiple times, how to deal with it.
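A worked example of the new label construction, mirroring the hunk above with tiny stand-in tensors. It assumes the `stable_speech` package from this commit is importable; the token ids match the new defaults (BOS 2049, EOS 2048):

```python
import torch
from stable_speech import build_delay_pattern_mask

audio_encoder_bos_token_id = 2049  # matches decoder_start_token_id in this commit
audio_encoder_eos_token_id = 2048  # EOS, reused as the pad value of the pattern

num_codebooks, seq_len = 4, 2
# stand-in codes "a..h", shaped (num_codebooks, seq_len)
labels = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)
bos = torch.full((num_codebooks, 1), audio_encoder_bos_token_id)
labels = torch.cat([bos, labels], dim=-1)  # prepend the BOS column

_, delay_pattern_mask = build_delay_pattern_mask(
    labels,
    bos_token_id=audio_encoder_bos_token_id,
    pad_token_id=audio_encoder_eos_token_id,
    max_length=labels.shape[-1] + num_codebooks,
    num_codebooks=num_codebooks,
)
# positions still marked -1 lie past the end of each codebook's sequence -> fill with EOS
labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
print(labels)         # rows look like [B, a, b, E, E, E, E] / [B, B, c, d, E, E, E] / ...
print(labels[:, 1:])  # the first column (all BOS) is dropped before storing the sample
```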
@@ -1047,12 +1051,6 @@ def main():
         del generate_labels
-        if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
-            if is_wandb_available():
-                from transformers.integrations import WandbCallback
-            else:
-                raise ValueError("`args.add_audio_samples_to_wandb=True` but wandb is not installed. See https://docs.wandb.ai/quickstart to install.")
         # for large datasets it is advised to run the preprocessing on a
         # single machine first with ``args.preprocessing_only`` since there will most likely
@@ -1467,4 +1465,5 @@ def main():
 if __name__ == "__main__":
+    set_start_method("spawn")
     main()
\ No newline at end of file
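The new `set_start_method("spawn")` call matters because `datasets.map` workers (the `multiprocess` library) cannot reuse a forked parent's CUDA context; the reasoning is inferred, not stated in the commit. A minimal sketch of the pattern, with a hypothetical worker function:

```python
from multiprocess import set_start_method

def encode_batch(batch):
    # anything that touches CUDA must run in a spawned worker, not a forked one
    return batch

if __name__ == "__main__":
    set_start_method("spawn")
    # e.g. dataset.map(encode_batch, batched=True, num_proc=2) is now safe with CUDA
```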
 from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
-from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
+from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
\ No newline at end of file
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2050):
+        vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2050, # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
+        vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2050,
-        bos_token_id=2048,
-        eos_token_id=2049,
+        pad_token_id=2048,
+        bos_token_id=2049,
+        eos_token_id=2048,
         tie_word_embeddings=False,
         **kwargs,
     ):
...
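A short sketch of the special-token layout these new defaults imply. The relationships are inferred from the diff rather than stated in it; in particular, MusicGen-style decoders allocate `vocab_size + 1` embedding rows, which is presumably where the BOS id (2049) lives even though it is outside the 2049-entry output vocabulary:

```python
ENCODEC_CODEBOOK_SIZE = 2048  # EnCodec codebook entries occupy ids 0..2047

eos_token_id = 2048  # first id after the codes; also reused as pad_token_id
bos_token_id = 2049  # also the decoder_start_token_id in the training configs
vocab_size = 2049    # codes + EOS; BOS is never a prediction target, so no logit for it

assert eos_token_id == ENCODEC_CODEBOOK_SIZE
assert bos_token_id == ENCODEC_CODEBOOK_SIZE + 1
assert vocab_size == ENCODEC_CODEBOOK_SIZE + 1
```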
@@ -60,6 +60,77 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
 ]
+
+def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+    """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+    the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+    seq_len = input_ids.shape[-1]
+    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+    input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+    return input_ids
+
+
+def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+    """Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous codebook by
+    one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+    are 4 codebooks and a max sequence length of 8; we have the delayed pattern mask of shape `(codebooks,
+    seq_len)`:
+    - [B, -1, -1, -1, -1, P, P, P]
+    - [B, B, -1, -1, -1, -1, P, P]
+    - [B, B, B, -1, -1, -1, -1, P]
+    - [B, B, B, B, -1, -1, -1, -1]
+    where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+    a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+    mask is set to the value in the prompt:
+    - [B, a, b, -1, -1, P, P, P]
+    - [B, B, c, d, -1, -1, P, P]
+    - [B, B, B, e, f, -1, -1, P]
+    - [B, B, B, B, g, h, -1, -1]
+    where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+    tokens in our prediction.
+    """
+    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+    input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
+    bsz, num_codebooks, seq_len = input_ids.shape
+    input_ids_shifted = (
+        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+    )
+    # we only apply the mask if we have a large enough seq len - otherwise we return as is
+    if max_length < 2 * num_codebooks - 1:
+        return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+    # fill the shifted ids with the prompt entries, offset by the codebook idx
+    for codebook in range(num_codebooks):
+        # mono channel - loop over the codebooks one-by-one
+        input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+    # construct a pattern mask that indicates the positions of padding tokens for each codebook
+    # first fill the upper triangular part (the EOS padding)
+    eos_delay_pattern = torch.triu(
+        torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
+    )
+    # then fill the lower triangular part (the BOS padding)
+    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+    bos_mask = ~(bos_delay_pattern).to(input_ids.device)
+    eos_mask = ~(eos_delay_pattern).to(input_ids.device)
+    mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+    input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
+    # find the first position to start generating - this is the first place we have the -1 token
+    # and will always be in the first codebook (since it has no codebook offset)
+    first_codebook_ids = input_ids[:, 0, :]
+    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+    if len(start_ids) > 0:
+        first_start_id = min(start_ids)
+    else:
+        # we have no tokens that need to be filled - return entire matrix of input ids
+        first_start_id = seq_len
+    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
+    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+    input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+    return input_ids, pattern_mask
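A worked example of the two functions above on the docstring's 4-codebook case. It assumes the `stable_speech` package from this commit is importable; note the prompt must begin with a BOS column, since the pattern always overwrites the first position of each row with BOS:

```python
import torch
from stable_speech import build_delay_pattern_mask, apply_delay_pattern_mask

num_codebooks = 4
bos, pad = 2049, 2048
# decoder input ids per codebook: [BOS, a, b]
prompt = torch.cat(
    [torch.full((num_codebooks, 1), bos), torch.arange(num_codebooks * 2).reshape(num_codebooks, 2)],
    dim=-1,
)
input_ids, pattern_mask = build_delay_pattern_mask(
    prompt, bos_token_id=bos, pad_token_id=pad, max_length=8, num_codebooks=num_codebooks
)
print(pattern_mask)
# tensor([[2049,    0,    1,   -1,   -1, 2048, 2048, 2048],
#         [2049, 2049,    2,    3,   -1,   -1, 2048, 2048],
#         [2049, 2049, 2049,    4,    5,   -1,   -1, 2048],
#         [2049, 2049, 2049, 2049,    6,    7,   -1,   -1]])
print(input_ids.shape)  # torch.Size([4, 3]): the prompt trimmed at the first -1 position

# during generation, freshly sampled ids only land where the mask is -1:
sampled = torch.full((num_codebooks, 8), 7)
print(apply_delay_pattern_mask(sampled, pattern_mask))
```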
 @dataclass
 class StableSpeechUnconditionalInput(ModelOutput):
@@ -982,7 +1053,6 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1031,16 +1101,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
-            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-            # loss = loss_fct(logits.transpose(1,3), labels)
-            # -100 labels are ignored
-            # TODO: probably no need for label_delay_pattern_mask
-            # mask = label_delay_pattern_mask[:, :labels.shape[1]]
-            # mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
-            mask = (labels != -100)
+            # we use every codebook token AND one single EOS at the end of each codebook
+            mask = (input_ids.transpose(1,2) != self.config.eos_token_id) & ((labels != -100))
             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
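The loop appears only as a header in this hunk; a hedged sketch of what a per-codebook cross-entropy over the new mask plausibly computes (the helper function, its name, and the averaging are assumptions, not the file's actual body):

```python
import torch
import torch.nn.functional as F

def per_codebook_ce(logits, labels, mask):
    # logits: (bsz, vocab_size, seq_len, num_codebooks)
    # labels, mask: (bsz, seq_len, num_codebooks); mask is boolean
    loss = 0.0
    num_codebooks = labels.shape[-1]
    for codebook in range(num_codebooks):
        codebook_logits = logits[..., codebook].transpose(1, 2)  # (bsz, seq_len, vocab)
        codebook_mask = mask[..., codebook]
        codebook_labels = labels[..., codebook]
        # keep only positions the mask marks as valid targets
        loss += F.cross_entropy(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
    return loss / num_codebooks
```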
@@ -1152,60 +1215,14 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
         """
-        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-        input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
-        bsz, num_codebooks, seq_len = input_ids.shape
         max_length = max_length if max_length is not None else self.generation_config.max_length
-        input_ids_shifted = (
-            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-        )
-        # we only apply the mask if we have a large enough seq len - otherwise we return as is
-        if max_length < 2 * num_codebooks - 1:
-            return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
-        # fill the shifted ids with the prompt entries, offset by the codebook idx
-        for codebook in range(num_codebooks):
-            # mono channel - loop over the codebooks one-by-one
-            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
-        # construct a pattern mask that indicates the positions of padding tokens for each codebook
-        # first fill the upper triangular part (the EOS padding)
-        eos_delay_pattern = torch.triu(
-            torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
-        )
-        # then fill the lower triangular part (the BOS padding)
-        bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
-        bos_mask = ~(bos_delay_pattern).to(input_ids.device)
-        eos_mask = ~(eos_delay_pattern).to(input_ids.device)
-        mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
-        # find the first position to start generating - this is the first place we have the -1 token
-        # and will always be in the first codebook (since it has no codebook offset)
-        first_codebook_ids = input_ids[:, 0, :]
-        start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
-        if len(start_ids) > 0:
-            first_start_id = min(start_ids)
-        else:
-            # we have no tokens that need to be filled - return entire matrix of input ids
-            first_start_id = seq_len
-        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-        pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
-        input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
-        return input_ids, pattern_mask
+        return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks)

     @staticmethod
     def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
         """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
         the mask is set to -1, and otherwise setting to the value detailed in the mask."""
-        seq_len = input_ids.shape[-1]
-        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
-        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
-        return input_ids
+        return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask)

     @torch.no_grad()
     def generate(
@@ -1219,7 +1236,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         **kwargs,
     ):
         """
+        # TODO: adapt this generate with latest change
         Generates sequences of token ids for models with a language modeling head.

         <Tip warning={true}>
@@ -1868,7 +1885,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
         prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
-        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1991,7 +2007,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             past_key_values=past_key_values,
             return_dict=return_dict,
             labels=labels,
-            label_delay_pattern_mask=label_delay_pattern_mask,
             **kwargs_decoder,
         )
@@ -2074,6 +2089,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
+            "prompt_hidden_states": prompt_hidden_states,
+            "prompt_attention_mask": prompt_attention_mask,
             "use_cache": use_cache,
         }
@@ -2564,9 +2581,17 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
         # revert the delay pattern mask by filtering the eos and bos token ids from the delay pattern mask
-        # TODO: probably won't work...
-        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
+        _, mask = self.decoder.build_delay_pattern_mask(
+            input_ids,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
+            max_length=output_ids.shape[1],
+        )
+        mask = (mask != generation_config.bos_token_id)&(mask != generation_config.pad_token_id)
+        output_ids = output_ids[mask].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
@@ -2577,8 +2602,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size
-        decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
-        if decode_in_batch.item():
+        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        if not decode_sequentially:
             output_values = self.audio_encoder.decode(
                 output_ids,
                 audio_scales=audio_scales,
...
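The check above batch-decodes only when no special ids survived in `output_ids`. The matching `else` branch lies outside this diff, so the per-sample fallback below is an assumption, sketched as if inside the method (`self`, `output_ids`, `batch_size`, and `audio_scales` as in the surrounding code):

```python
# assumed fallback: trim each sample to frames free of special ids, decode it alone
output_values = []
for sample_id in range(batch_size):
    sample = output_ids[sample_id]  # (num_codebooks, seq_len)
    # EnCodec codes are 0..2047 in this commit; 2048/2049 are EOS(PAD)/BOS
    valid_frames = (sample < 2048).all(dim=0)
    trimmed = sample[:, valid_frames][None, ...]  # keep a batch dim of 1
    decoded = self.audio_encoder.decode(trimmed, audio_scales=[audio_scales[sample_id]])
    output_values.append(decoded.audio_values)
```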