Commit eec6e3d6 authored by Yoach Lacombe

fix bandwidth

parent dbb95132
@@ -256,6 +256,10 @@ class ModelArguments:
         default=400,  # TODO
         metadata={"help": "Whether to do sampling or greedy decoding."},
     )
+    bandwidth: float = field(
+        default=3,  # TODO
+        metadata={"help": "Audio encoder bandwidth."},
+    )
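Note on the new `bandwidth` argument: with an EnCodec-style audio encoder, the target bandwidth determines how many codebooks the residual quantizer uses (at 3 kbps, the 24 kHz EnCodec checkpoint uses 4 codebooks: 4 codebooks × 10 bits × 75 frames/s ≈ 3 kbps). A minimal sketch with the `transformers` `EncodecModel`, assuming `facebook/encodec_24khz` as a stand-in for the actual audio encoder:

```python
import torch
from transformers import EncodecModel

# Stand-in for the training script's audio encoder (assumption: EnCodec-style).
model = EncodecModel.from_pretrained("facebook/encodec_24khz")
wav = torch.randn(1, 1, 24_000)  # 1 second of fake mono audio at 24 kHz

with torch.no_grad():
    # bandwidth must be one of the checkpoint's target_bandwidths, e.g. 1.5/3/6/12/24
    enc = model.encode(wav, bandwidth=3.0)

# (nb_chunks, batch, codebooks, frames): at 3 kbps this is (1, 1, 4, 75)
print(enc.audio_codes.shape)
```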
@@ -909,6 +913,7 @@ def main():
     audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
     max_length = model.generation_config.max_length
     num_codebooks = model.decoder.config.num_codebooks
+    bandwidth = model_args.bandwidth

     # resample target audio
     raw_datasets = raw_datasets.cast_column(
@@ -971,7 +976,7 @@ def main():
         len_audio = batch.pop("len_audio")
         audio_decoder.to(batch["input_values"].device).eval()
         with torch.no_grad():
-            labels = audio_decoder.encode(**batch)["audio_codes"]
+            labels = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
         output = {}
         output["len_audio"] = len_audio
         # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
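The comment above describes the reshape applied downstream; a toy illustration of the axis reordering (shapes hypothetical):

```python
import torch

# EnCodec-style codes: (nb_chunks=1, bsz, codebooks, seq_len)
audio_codes = torch.randint(0, 1024, (1, 8, 4, 250))
# Drop the chunk axis, then swap codebooks and time -> (bsz, seq_len, codebooks)
labels = audio_codes[0].transpose(1, 2)
print(labels.shape)  # torch.Size([8, 250, 4])
```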
@@ -1011,7 +1016,7 @@ def main():
         len_ = int(all_ratios[idx] * all_lens[idx])
         labels = labels[:, :, :len_]
-        # labels = labels[:, :, :(len_)%10+20]  # TODO: change
+        # labels = labels[:, :, :(len_)%10+500]  # TODO: change

         # add bos
         labels = torch.cat([bos_labels, labels], dim=-1)
@@ -1047,6 +1052,7 @@ def main():
         input_columns=["input_ids", "prompt_input_ids"],
         desc="Postprocessing labeling",
         with_indices=True,
+        writer_batch_size=200,
     )
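On `writer_batch_size=200`: `datasets.Dataset.map` buffers this many processed examples before flushing them to the Arrow cache file, so lowering it from the default (1000) bounds memory at the cost of more writes. A toy usage sketch:

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1_000))})
# Flush every 200 processed rows to keep the writer's RAM footprint small.
ds = ds.map(lambda ex: {"y": ex["x"] * 2}, writer_batch_size=200)
print(ds[0])  # {'x': 0, 'y': 0}
```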
@@ -1065,22 +1071,6 @@ def main():
         logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return

-    # Now save everything to be able to create a single processor later
-    # make sure all processes wait until data is saved
-    with accelerator.main_process_first():
-        # only the main process saves them
-        if accelerator.is_main_process:
-            # save feature extractor, tokenizer and config
-            if (model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name) or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
-                prompt_tokenizer.save_pretrained(training_args.output_dir)
-            else:
-                logger.warning(f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
-                prompt_tokenizer.save_pretrained(training_args.output_dir)
-            feature_extractor.save_pretrained(training_args.output_dir)
-            config.save_pretrained(training_args.output_dir)

     # 6. Next, we can prepare the training.
@@ -1120,7 +1110,6 @@ def main():
     def compute_metrics(audios, descriptions, prompts, device="cpu"):
         input_ids = descriptions
         input_ids[input_ids == -100] = description_tokenizer.pad_token_id
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
@@ -1171,7 +1160,7 @@ def main():
     lr_scheduler = get_scheduler(
         name=training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
+        num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * accelerator.num_processes,
         num_training_steps=total_train_steps * accelerator.num_processes,
     )
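The switch to `training_args.get_warmup_steps(total_train_steps)` makes the scheduler honour `warmup_ratio` as well as an absolute `warmup_steps`. Roughly, `transformers.TrainingArguments.get_warmup_steps` behaves like this sketch:

```python
import math

def get_warmup_steps(num_training_steps: int, warmup_steps: int, warmup_ratio: float) -> int:
    # An absolute step count takes precedence; otherwise derive one from the ratio.
    return warmup_steps if warmup_steps > 0 else math.ceil(num_training_steps * warmup_ratio)

print(get_warmup_steps(10_000, 0, 0.05))   # 500 (from the ratio)
print(get_warmup_steps(10_000, 400, 0.05)) # 400 (absolute wins)
```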
@@ -1231,6 +1220,22 @@ def main():
     os.makedirs(training_args.output_dir, exist_ok=True)
     accelerator.wait_for_everyone()

+    # Now save everything to be able to create a single processor later
+    # make sure all processes wait until data is saved
+    with accelerator.main_process_first():
+        # only the main process saves them
+        if accelerator.is_main_process:
+            # save feature extractor, tokenizer and config
+            if (model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name) or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
+                prompt_tokenizer.save_pretrained(training_args.output_dir)
+            else:
+                logger.warning(f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
+                prompt_tokenizer.save_pretrained(training_args.output_dir)
+            feature_extractor.save_pretrained(training_args.output_dir)
+            config.save_pretrained(training_args.output_dir)
+
     if checkpoint is not None:
         accelerator.load_state(checkpoint)
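On the relocated save block: `accelerator.main_process_first()` lets rank 0 run the body before the other ranks enter it, and the `is_main_process` guard ensures only one process writes to disk. A runnable toy version of the pattern (paths and file contents hypothetical):

```python
import os
from accelerate import Accelerator

accelerator = Accelerator()
output_dir = "output"  # hypothetical path

# Rank 0 executes the body first; other ranks wait, then see the files.
with accelerator.main_process_first():
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "config.json"), "w") as f:
            f.write("{}")  # stand-in for config.save_pretrained(output_dir)

accelerator.wait_for_everyone()
```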
@@ -212,6 +212,7 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
         position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
         # expand embeddings if needed
         if seq_len > self.weights.size(0):
+            # TODO: doesn't work
             self.make_weights(seq_len + self.offset, self.embedding_dim)
         return self.weights.index_select(0, position_ids.view(-1)).detach()
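For context on the table being extended: sinusoidal position weights typically follow the classic interleaved sin/cos recipe, and `make_weights` has to rebuild the whole table once `seq_len` outgrows it (the `TODO` above flags that the on-the-fly rebuild is not taking effect as hoped). A sketch of one common construction, not necessarily the exact `StableSpeech` one:

```python
import math
import torch

def make_sinusoidal_weights(num_positions: int, embedding_dim: int) -> torch.Tensor:
    half_dim = embedding_dim // 2
    # Geometrically spaced frequencies, as in "Attention Is All You Need".
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000.0) / (half_dim - 1)))
    angles = torch.arange(num_positions, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
    # sin half then cos half -> (num_positions, embedding_dim)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

weights = make_sinusoidal_weights(1024, 512)
print(weights.shape)  # torch.Size([1024, 512])
```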
@@ -2620,7 +2621,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             else:
                 output_values.append(torch.zeros((1, 1, 1)).to(self.device))

         # TODO: we should keep track of output length as well. Not really straightforward tbh
-        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1).squeeze(1)
+        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).squeeze(-1).squeeze(-1)

         if generation_config.return_dict_in_generate:
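On the padding change: `pad_sequence` pads each tensor along its first dimension and stacks the results on a new batch dimension, so the two trailing singleton axes can simply be squeezed away rather than transposed around. A toy illustration (shapes are assumptions):

```python
import torch

# One decoded clip per batch element, laid out as (samples, 1, 1) so that
# pad_sequence pads along the time axis; trailing dims must match.
clips = [torch.randn(n, 1, 1) for n in (120, 80, 100)]
padded = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True, padding_value=0)
print(padded.shape)                          # torch.Size([3, 120, 1, 1])
print(padded.squeeze(-1).squeeze(-1).shape)  # torch.Size([3, 120])
```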