Commit 0f6d59d4 authored by Yoach Lacombe

latest changes

parent 11fcc066
@@ -24,11 +24,11 @@
   "description_column_name": "text_description",
   "prompt_column_name": "text",
-  "max_train_samples": 1000,
-  "max_eval_samples": 200,
-  "max_duration_in_seconds": 20,
+  "max_train_samples": 20,
+  "max_eval_samples": 10,
+  "max_duration_in_seconds": 30,
   "min_duration_in_seconds": 1.0,
   "add_audio_samples_to_wandb": true,
@@ -36,30 +36,36 @@
   "preprocessing_num_workers": 1,
-  "pad_token_id": 2049,
+  "pad_token_id": 2050,
   "decoder_start_token_id": 2048,
   "do_train": true,
-  "num_train_epochs": 1,
+  "num_train_epochs": 120,
   "gradient_accumulation_steps": 1,
-  "gradient_checkpointing": true,
-  "per_device_train_batch_size": 8,
-  "learning_rate": 1e-6,
+  "gradient_checkpointing": false,
+  "per_device_train_batch_size": 2,
+  "learning_rate": 1e-3,
   "adam_beta1": 0.9,
-  "adam_beta2": 0.95,
+  "adam_beta2": 0.999,
   "weight_decay": 0.1,
-  "logging_steps": 25,
+  "lr_scheduler_type": "cosine",
+  "warmup_ratio": 0.1,
+  "logging_steps": 1,
+  "freeze_text_encoder": true,
   "do_eval": true,
   "predict_with_generate": true,
   "include_inputs_for_metrics": true,
-  "evaluation_strategy": "epoch",
+  "evaluation_strategy": "steps",
+  "eval_steps": 600,
   "per_device_eval_batch_size": 8,
   "generation_max_length": 400,
-  "fp16": true,
+  "fp16": false,
   "seed": 456,
   "dataloader_num_workers":8
...
+{
+  "model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
+  "feature_extractor_name":"facebook/encodec_24khz",
+  "description_tokenizer_name":"t5-base",
+  "prompt_tokenizer_name":"t5-base",
+  "push_to_hub": false,
+  "hub_model_id": "stable-speech-mini",
+  "report_to": ["wandb"],
+  "overwrite_output_dir": true,
+  "output_dir": "/home/yoach/dataspeech/artefacts/training/",
+  "train_dataset_name": "blabble-io/libritts_r",
+  "train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
+  "train_dataset_config_name": "clean",
+  "train_split_name": "train.clean.360",
+  "eval_dataset_name": "blabble-io/libritts_r",
+  "eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
+  "eval_dataset_config_name": "clean",
+  "eval_split_name": "train.clean.360",
+  "target_audio_column_name": "audio",
+  "description_column_name": "text_description",
+  "prompt_column_name": "text",
+  "max_train_samples": 12,
+  "max_eval_samples": 12,
+  "max_duration_in_seconds": 30,
+  "min_duration_in_seconds": 1.0,
+  "add_audio_samples_to_wandb": true,
+  "id_column_name": "id",
+  "preprocessing_num_workers": 1,
+  "pad_token_id": 2050,
+  "decoder_start_token_id": 2048,
+  "do_train": true,
+  "num_train_epochs": 20,
+  "gradient_accumulation_steps": 1,
+  "gradient_checkpointing": false,
+  "per_device_train_batch_size": 3,
+  "learning_rate": 1e-3,
+  "adam_beta1": 0.9,
+  "adam_beta2": 0.999,
+  "weight_decay": 0.1,
+  "lr_scheduler_type": "cosine",
+  "warmup_ratio": 0.1,
+  "freeze_text_encoder": true,
+  "do_eval": true,
+  "predict_with_generate": true,
+  "include_inputs_for_metrics": true,
+  "evaluation_strategy": "steps",
+  "eval_steps": 10,
+  "per_device_eval_batch_size": 3,
+  "generation_max_length": 400,
+  "do_sample": true,
+  "logging_steps": 15,
+  "dtype": "float32",
+  "seed": 456,
+  "dataloader_num_workers":8
+}
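For context, JSON configs like the one above are usually fed to the training script through `HfArgumentParser`. A minimal sketch of that pattern, with a hypothetical, pared-down dataclass standing in for the script's real argument groups:

```python
import json
import sys
from dataclasses import dataclass
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # hypothetical, pared-down argument group; the real script defines its own
    train_dataset_name: Optional[str] = None
    max_train_samples: Optional[int] = None
    max_duration_in_seconds: float = 20.0

parser = HfArgumentParser(DataTrainingArguments)
with open(sys.argv[1]) as f:
    # allow_extra_keys skips the many fields this sketch does not model
    (data_args,) = parser.parse_dict(json.load(f), allow_extra_keys=True)
print(data_args.train_dataset_name, data_args.max_train_samples)
```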
@@ -4,16 +4,16 @@ from transformers import AutoConfig
 decoder_config = StableSpeechDecoderConfig(
     max_position_embeddings=2048,
-    num_hidden_layers=2,
-    ffn_dim=256,
-    num_attention_heads=4,
+    num_hidden_layers=4,
+    ffn_dim=512,
+    num_attention_heads=8,
     layerdrop=0.0,
     use_cache=True,
     activation_function="gelu",
-    hidden_size=256,
-    dropout=0.1,
-    attention_dropout=0.1,
-    activation_dropout=0.1,
+    hidden_size=512,
+    dropout=0.0,
+    attention_dropout=0.0,
+    activation_dropout=0.0,
 )
 # TODO: ?? how to make it stop ?
@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2049
+model.generation_config.pad_token_id = 2050
 model.generation_config.eos_token_id = 2049
 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
-model.generation_config.do_sample = True
-model.generation_config.guidance_scale = 3.0
+model.generation_config.do_sample = False  # True
+model.generation_config.guidance_scale = 1  # 3.0
 model.save_pretrained("/home/yoach/dataspeech/artefacts/tiny-model/")
\ No newline at end of file
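The move from `pad_token_id = 2049` to `2050` separates padding from the eos token: the decoder vocabulary now covers the 2048 EnCodec codes plus a bos id (2048) and an eos id (2049), with padding pushed to a dedicated id one past the vocabulary. A quick summary of the layout, with the constants taken directly from the hunks above:

```python
# Token-id layout after this commit:
ENCODEC_CODES = 2048        # codebook entries 0 .. 2047
BOS_TOKEN_ID = 2048         # also used as decoder_start_token_id
EOS_TOKEN_ID = 2049
PAD_TOKEN_ID = 2050         # now distinct from the eos id
VOCAB_SIZE = 2050           # 2048 codes + bos + eos
EMBED_DIM = VOCAB_SIZE + 1  # one extra embedding row so the pad id stays in range
assert PAD_TOKEN_ID < EMBED_DIM
```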
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2049):
+        vocab_size (`int`, *optional*, defaults to 2050):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos token)
+        vocab_size=2050,  # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2049,
+        pad_token_id=2050,
         bos_token_id=2048,
         eos_token_id=2049,
         tie_word_embeddings=False,
...
@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         self.num_codebooks = config.num_codebooks
         self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
-        embed_dim = config.vocab_size + 1  # TODO: not right dim
+        embed_dim = config.vocab_size + 1  # + 1 for pad token id
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
         )
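For readers unfamiliar with MusicGen-style decoders, the `ModuleList` above holds one embedding table per codebook, and the forward pass sums the per-codebook embeddings into a single hidden state. A minimal sketch of that consumption pattern, with toy sizes rather than the model's real forward:

```python
import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden_size = 4, 2050, 512
embed_dim = vocab_size + 1  # extra row for the pad id, as in the hunk above
embed_tokens = nn.ModuleList(
    [nn.Embedding(embed_dim, hidden_size) for _ in range(num_codebooks)]
)

input_ids = torch.randint(0, vocab_size, (2, num_codebooks, 10))  # (bsz, codebooks, seq_len)
# one table per codebook, summed over the codebook axis
inputs_embeds = sum(embed_tokens[k](input_ids[:, k]) for k in range(num_codebooks))
print(inputs_embeds.shape)  # torch.Size([2, 10, 512])
```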
@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -991,6 +993,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+            # TODO: delay_pattern_mask
         Returns:
         """
@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         loss = None
         if labels is not None:
+            loss = torch.zeros([], device=self.device)
             # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
             logits = lm_logits[:,:,-labels.shape[1]:]
             loss_fct = CrossEntropyLoss()
             loss = torch.zeros([], device=self.device)
-            # per codebook cross-entropy
-            # -100 labels are ignored
             # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
+            labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
             labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-            loss = loss_fct(logits.transpose(1,3), labels)
+            # loss = loss_fct(logits.transpose(1,3), labels)
+            # -100 labels are ignored
+            # TODO: probably no need for label_delay_pattern_mask
+            # mask = label_delay_pattern_mask[:, :labels.shape[1]]
+            # mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
+            mask = (labels != -100)
+            # per codebook cross-entropy
+            for codebook in range(self.config.num_codebooks):
+                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+                codebook_mask = mask[..., codebook].contiguous().view(-1)
+                codebook_labels = labels[..., codebook].contiguous().view(-1)
+                codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
+                loss += codebook_loss
+            loss = loss / self.config.num_codebooks
         # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
         lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
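The rewritten loss replaces the single `transpose(1,3)` cross-entropy with an explicit loop: bos and pad positions are first converted to `-100`, a boolean mask keeps only the remaining positions, and each codebook contributes an equally weighted cross-entropy term. A standalone sketch with toy shapes (runnable as-is; the dimensions mirror the comments in the hunk):

```python
import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab = 2, 4, 6, 2050
logits = torch.randn(bsz, num_codebooks, seq_len, vocab)   # lm head output
labels = torch.randint(0, 2048, (bsz, seq_len, num_codebooks))
labels[:, -1, :] = -100                                    # e.g. masked pad positions

loss_fct = CrossEntropyLoss()
loss = torch.zeros([])
mask = labels != -100
for k in range(num_codebooks):
    k_logits = logits[:, k].reshape(-1, vocab)  # (bsz * seq_len, vocab)
    k_mask = mask[..., k].reshape(-1)
    k_labels = labels[..., k].reshape(-1)
    loss += loss_fct(k_logits[k_mask], k_labels[k_mask])
loss = loss / num_codebooks                     # equal weight per codebook
print(loss)
```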
@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if delay_pattern_mask is None:
             input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
                 input_ids,
-                bos_token_id=self.generation_config.decoder_start_token_id,
-                eos_token_id=self.generation_config.eos_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                pad_token_id=self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         }
     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
         seq_len)`:
-        - [B, -1, -1, -1, -1, E, E, E]
-        - [B, B, -1, -1, -1, -1, E, E]
-        - [B, B, B, -1, -1, -1, -1, E]
+        - [B, -1, -1, -1, -1, P, P, P]
+        - [B, B, -1, -1, -1, -1, P, P]
+        - [B, B, B, -1, -1, -1, -1, P]
         - [B, B, B, B, -1, -1, -1, -1]
-        where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
+        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
         a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
         mask is set to the value in the prompt:
-        - [B, a, b, -1, -1, E, E, E]
-        - [B, B, c, d, -1, -1, E, E]
-        - [B, B, B, e, f, -1, -1, E]
+        - [B, a, b, -1, -1, P, P, P]
+        - [B, B, c, d, -1, -1, P, P]
+        - [B, B, B, e, f, -1, -1, P]
         - [B, B, B, B, g, h, -1, -1]
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
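The renamed `pad_token_id` argument only changes what fills the trailing triangle; the delay pattern itself is unchanged. A toy reconstruction of the mask described in the docstring (4 codebooks, seq_len 8), assuming the token ids from this commit:

```python
import torch

def toy_delay_pattern(num_codebooks=4, seq_len=8, bos=2048, pad=2050):
    # -1 marks positions that stay free for prediction
    mask = torch.full((num_codebooks, seq_len), -1, dtype=torch.long)
    for k in range(num_codebooks):
        mask[k, : k + 1] = bos            # k+1 leading BOS tokens
        trailing = num_codebooks - 1 - k  # codebook k trails by (K-1-k) pads
        if trailing:
            mask[k, seq_len - trailing:] = pad
    return mask

print(toy_delay_pattern())
# tensor([[2048,   -1,   -1,   -1,   -1, 2050, 2050, 2050],
#         [2048, 2048,   -1,   -1,   -1,   -1, 2050, 2050],
#         [2048, 2048, 2048,   -1,   -1,   -1,   -1, 2050],
#         [2048, 2048, 2048, 2048,   -1,   -1,   -1,   -1]])
```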
@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         bos_mask = ~(bos_delay_pattern).to(input_ids.device)
         eos_mask = ~(eos_delay_pattern).to(input_ids.device)
         mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
+        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
         # find the first position to start generating - this is the first place we have the -1 token
         # and will always be in the first codebook (since it has no codebook offset)
@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
-            bos_token_id=generation_config.decoder_start_token_id,
-            eos_token_id=generation_config.eos_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
             max_length=generation_config.max_length,
         )
@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
+        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # TODO: verify prompt_attention_mask
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            # TODO: verify it does what's expected
             decoder_input_ids = shift_tokens_right(
-                labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id
-            )
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            ).transpose(1,2)
         elif decoder_input_ids is None and decoder_inputs_embeds is None:
             audio_encoder_outputs = self.audio_encoder(
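Moving the `transpose(1,2)` after `shift_tokens_right` matters because the standard HF-style helper shifts along dim 1. With labels laid out as `(bsz, seq_len, num_codebooks)`, which is the layout the loss code indexes with `labels[..., codebook]`, the old order shifted across codebooks; the new order shifts across time and only then rearranges to `(bsz, num_codebooks, seq_len)`. A sketch, assuming the usual helper definition:

```python
import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # standard HF-style helper: shifts along dim 1
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.randint(0, 2048, (2, 5, 4))  # (bsz, seq_len, num_codebooks)
# old (buggy): transpose first, so the shift ran over the codebook axis
# new (fixed): shift over seq_len, then transpose to (bsz, num_codebooks, seq_len)
decoder_input_ids = shift_tokens_right(labels, 2050, 2048).transpose(1, 2)
print(decoder_input_ids.shape)  # torch.Size([2, 4, 5])
```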
@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             past_key_values=past_key_values,
             return_dict=return_dict,
             labels=labels,
+            label_delay_pattern_mask=label_delay_pattern_mask,
             **kwargs_decoder,
         )
@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_delay_pattern_mask is None:
             decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
-                bos_token_id=self.generation_config.decoder_start_token_id,
-                eos_token_id=self.generation_config.eos_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                pad_token_id=self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         return model_kwargs
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id)
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
-            bos_token_id=generation_config.decoder_start_token_id,
-            eos_token_id=generation_config.eos_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
             max_length=generation_config.max_length,
         )
         # stash the delay mask so that we don't have to recompute in each forward pass
@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
+        # TODO: probably won't work...
+        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
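Despite the `# TODO: probably won't work...` note, this reshape is shape-consistent for a full delay pattern: codebook k carries k+1 bos tokens and (num_codebooks-1-k) pads, so every codebook drops exactly num_codebooks positions and the flat boolean index folds back cleanly. A toy check (1 sample, 2 codebooks, seq_len 5), using the ids from this commit:

```python
import torch

B, P = 2048, 2050  # bos / pad ids from this commit
# delay pattern mask stacked as (bsz * codebooks, seq_len)
mask = torch.tensor([[B, -1, -1, -1, P],
                     [B, B, -1, -1, -1]])
output_ids = torch.tensor([[B, 11, 12, 13, P],
                           [B, B, 21, 22, 23]])
keep = (mask != B) & (mask != P)
# each codebook drops the same number of positions, so the reshape is well-defined
print(output_ids[keep].reshape(1, 2, -1))  # tensor([[[11, 12, 13], [21, 22, 23]]])
```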
@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             output_values = self.audio_encoder.decode(
                 output_ids,
                 audio_scales=audio_scales,
-            ).audio_values
+            ).audio_values.squeeze(1)
         else:
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)).sum(dim=(0,1)) == 0)
+                sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)|(sample == generation_config.pad_token_id)).sum(dim=(0,1)) == 0)
+                if sample_mask.sum()>0:
                     sample = sample[:, :, sample_mask]
                     sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
                     output_values.append(sample.transpose(0,2))
+                else:
+                    output_values.append(torch.zeros((1,1,1)).to(self.device))
         # TODO: we should keep track of output length as well. Not really straightfoward tbh
-        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1)
+        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1).squeeze(1)
         if generation_config.return_dict_in_generate:
...
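The two new `squeeze` calls in the hunk above are easier to follow with explicit shapes: each decoded sample arrives as `(1, channels, seq_i)` and is transposed to `(seq_i, channels, 1)` so that `pad_sequence` can batch the variable lengths. A shape walkthrough for mono audio (toy lengths, not the model's output):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# per-sample decoded audio: (1, channels, seq_i) -> transpose(0, 2) -> (seq_i, channels, 1)
a = torch.randn(1, 1, 120).transpose(0, 2)
b = torch.randn(1, 1, 80).transpose(0, 2)

batch = pad_sequence([a, b], batch_first=True, padding_value=0)  # (2, 120, 1, 1)
out = batch.transpose(1, 2).squeeze(-1).squeeze(1)               # (2, 120): (bsz, samples)
print(out.shape)
```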