Commit eec6e3d6 authored by Yoach Lacombe

fix bandwidth

parent dbb95132
@@ -256,6 +256,10 @@ class ModelArguments:
         default=400,  # TODO
         metadata={"help": "Whether to do sampling or greedy decoding."},
     )
+    bandwidth: float = field(
+        default=3,  # TODO
+        metadata={"help": "Audio encoder bandwidth."},
+    )
@@ -909,6 +913,7 @@ def main():
     audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
     max_length = model.generation_config.max_length
     num_codebooks = model.decoder.config.num_codebooks
+    bandwidth = model_args.bandwidth

     # resample target audio
     raw_datasets = raw_datasets.cast_column(
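Review note: the `cast_column` call that the hunk truncates is the standard `datasets` resampling idiom. A minimal sketch (dataset and column names hypothetical):

```python
# Minimal sketch of resampling via cast_column (dataset/column names are
# hypothetical): casting to Audio(sampling_rate=...) makes the library
# decode and resample the audio lazily on access.
from datasets import Audio, load_dataset

ds = load_dataset("some/audio-dataset", split="train")  # hypothetical
ds = ds.cast_column("audio", Audio(sampling_rate=44_100))
```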
@@ -971,7 +976,7 @@ def main():
         len_audio = batch.pop("len_audio")
         audio_decoder.to(batch["input_values"].device).eval()
         with torch.no_grad():
-            labels = audio_decoder.encode(**batch)["audio_codes"]
+            labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
         output = {}
         output["len_audio"] = len_audio
         # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
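Review note: a tiny illustration of the shape change described in the comment above, one way to realize it with made-up sizes:

```python
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks); sizes made up.
import torch

audio_codes = torch.randint(0, 1024, (1, 8, 4, 250))
labels = audio_codes.squeeze(0).transpose(1, 2)
assert labels.shape == (8, 250, 4)
```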
@@ -1011,7 +1016,7 @@ def main():
         len_ = int(all_ratios[idx] * all_lens[idx])
         labels = labels[:, :, :len_]
-        # labels = labels[:, :, :(len_)%10+20]  # TODO: change
+        # labels = labels[:, :, :(len_)%10+500]  # TODO: change

         # add bos
         labels = torch.cat([bos_labels, labels], dim=-1)
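Review note: the `bos_labels` concatenation prepends one BOS frame across every codebook. A hedged sketch with hypothetical sizes and token id:

```python
# Sketch of prepending a BOS frame across all codebooks (sizes and the
# bos_token_id value are hypothetical).
import torch

num_codebooks, seq_len, bos_token_id = 4, 250, 2048
labels = torch.randint(0, 1024, (1, num_codebooks, seq_len))
bos_labels = torch.full((1, num_codebooks, 1), bos_token_id)
labels = torch.cat([bos_labels, labels], dim=-1)  # (1, num_codebooks, seq_len + 1)
```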
@@ -1047,6 +1052,7 @@ def main():
         input_columns=["input_ids", "prompt_input_ids"],
         desc="Postprocessing labeling",
         with_indices=True,
+        writer_batch_size=200,
     )
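Review note: `writer_batch_size` is a real `datasets.Dataset.map` parameter: it caps how many processed examples are buffered before being flushed to the Arrow cache file, trading a little speed for lower peak memory on maps with large per-example outputs (such as audio codes here). Minimal usage sketch:

```python
# writer_batch_size bounds the in-memory buffer of datasets.Dataset.map;
# smaller values flush to the Arrow cache more often and use less RAM.
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1_000))})
ds = ds.map(lambda ex: {"y": ex["x"] * 2}, writer_batch_size=200)
```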
@@ -1065,22 +1071,6 @@ def main():
         logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return

-    # Now save everything to be able to create a single processor later
-    # make sure all processes wait until data is saved
-    with accelerator.main_process_first():
-        # only the main process saves them
-        if accelerator.is_main_process:
-            # save feature extractor, tokenizer and config
-            if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
-                prompt_tokenizer.save_pretrained(training_args.output_dir)
-            else:
-                logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
-                prompt_tokenizer.save_pretrained(training_args.output_dir)
-            feature_extractor.save_pretrained(training_args.output_dir)
-            config.save_pretrained(training_args.output_dir)

     # 6. Next, we can prepare the training.
@@ -1120,7 +1110,6 @@ def main():
     def compute_metrics(audios, descriptions, prompts, device="cpu"):
         input_ids = descriptions
-        input_ids[input_ids == -100] = description_tokenizer.pad_token_id
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
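Review note: the deleted line mapped the `-100` label padding back to a real token id before decoding. Removing it assumes `descriptions` no longer contains `-100`; otherwise `batch_decode` would receive invalid (negative) indices. The pattern being removed, for reference:

```python
# The removed pattern: -100 is the ignore-index used for loss masking, not a
# vocabulary id, so it must be replaced before batch_decode can run.
import torch

input_ids = torch.tensor([[5, 6, -100, -100]])
input_ids[input_ids == -100] = 0  # e.g. tokenizer.pad_token_id
```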
@@ -1171,7 +1160,7 @@ def main():
     lr_scheduler = get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
+        num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
         num_training_steps=total_train_steps * accelerator.num_processes,
     )
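Review note: switching to `get_warmup_steps` makes the scheduler honour `--warmup_ratio` as well as an explicit `--warmup_steps`. A rough sketch of what `transformers.TrainingArguments.get_warmup_steps` computes (simplified, from memory, not the verbatim implementation):

```python
# Simplified sketch of TrainingArguments.get_warmup_steps: an explicit
# warmup_steps wins; otherwise the count is derived from warmup_ratio.
import math

def get_warmup_steps(num_training_steps: int, warmup_steps: int, warmup_ratio: float) -> int:
    return warmup_steps if warmup_steps > 0 else math.ceil(num_training_steps * warmup_ratio)
```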
@@ -1231,6 +1220,22 @@ def main():
     os.makedirs(training_args.output_dir, exist_ok=True)
     accelerator.wait_for_everyone()

+    # Now save everything to be able to create a single processor later
+    # make sure all processes wait until data is saved
+    with accelerator.main_process_first():
+        # only the main process saves them
+        if accelerator.is_main_process:
+            # save feature extractor, tokenizer and config
+            if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
+                prompt_tokenizer.save_pretrained(training_args.output_dir)
+            else:
+                logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
+                prompt_tokenizer.save_pretrained(training_args.output_dir)
+            feature_extractor.save_pretrained(training_args.output_dir)
+            config.save_pretrained(training_args.output_dir)

     if checkpoint is not None:
         accelerator.load_state(checkpoint)
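Review note on the moved block (pre-existing code, not changed by this commit): the `logger.warning` string is missing its `f` prefix, so the tokenizer names are printed literally, and `A and B or C` relies on `and` binding tighter than `or`, which may not match the intended grouping. A hedged fix sketch, assuming the intent is "save one tokenizer when both names match or both are unset":

```python
# Hypothetical fix sketch (intent assumed, not confirmed by the commit).
same_tokenizer = (
    model_args.prompt_tokenizer_name == model_args.description_tokenizer_name
    or (model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name is None)
)
if same_tokenizer:
    prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
    logger.warning(
        f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer "
        f"('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
    )
    prompt_tokenizer.save_pretrained(training_args.output_dir)
```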
...
@@ -212,6 +212,7 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
         position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
         # expand embeddings if needed
         if seq_len > self.weights.size(0):
+            # TODO: doesn't work
             self.make_weights(seq_len + self.offset, self.embedding_dim)
         return self.weights.index_select(0, position_ids.view(-1)).detach()
@@ -2620,7 +2621,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             else:
                 output_values.append(torch.zeros((1, 1, 1)).to(self.device))
         # TODO: we should keep track of output length as well. Not really straightforward tbh
-        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1).squeeze(1)
+        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)

         if generation_config.return_dict_in_generate:
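Review note: the simplification relies on `pad_sequence` padding each element along its first dimension and, with `batch_first=True`, stacking to `(batch, max_len, *trailing_dims)`. A small check, assuming decoded waveforms shaped `(seq_len, 1, 1)` (which matches the `torch.zeros((1, 1, 1))` placeholder above):

```python
# pad_sequence pads dim 0 of each tensor; with batch_first=True the output is
# (batch, max_len, *trailing). Waveform shapes here are an assumption.
import torch
from torch.nn.utils.rnn import pad_sequence

a, b = torch.randn(120, 1, 1), torch.randn(80, 1, 1)
out = pad_sequence([a, b], batch_first=True, padding_value=0)  # (2, 120, 1, 1)
out = out.squeeze(-1).squeeze(-1)                              # (2, 120)
assert out.shape == (2, 120)
```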
...