Commit 5acad845 authored by Yoach Lacombe

Add the possibility to precompute text hidden states + fix generation
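The idea in brief: when the T5 text encoder is frozen, its hidden states can be computed once, cached in the dataset, and passed to the model as `encoder_outputs`, so the encoder never runs during training or generation. A minimal sketch of the mechanism (the t5-small checkpoint and the call sites are illustrative, not this script's API):

import torch
from transformers import AutoTokenizer, T5EncoderModel
from transformers.modeling_outputs import BaseModelOutput

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
encoder = T5EncoderModel.from_pretrained("t5-small").eval()

inputs = tokenizer(["a calm female voice", "a fast male speaker"], padding=True, return_tensors="pt")

# Run the frozen encoder once, offline, and cache the hidden states.
with torch.no_grad():
    hidden = encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state

# Training and generation later consume the cache instead of re-running the encoder:
encoder_outputs = BaseModelOutput(last_hidden_state=hidden)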

parent 80da6b4c
@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
from transformers import AutoConfig, AutoModel
from stable_speech import DACConfig, DACModel
from transformers.modeling_outputs import BaseModelOutput
AutoConfig.register("dac", DACConfig)
AutoModel.register(DACConfig, DACModel)
@@ -464,6 +464,14 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
)
},
)
text_encode_per_device_eval_batch_size: int = field(
default=8,
metadata={
"help": (
"TODO"
)
},
)
@dataclass
class DataCollatorEncodecWithPadding:
@@ -531,9 +539,16 @@ class DataCollatorStableSpeechWithPadding:
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
if "encoder_outputs" in features[0]:
    input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding)
    encoder_hidden_states = [torch.tensor(feature["encoder_outputs"]["last_hidden_state"]) for feature in features]
    encoder_hidden_states = torch.nn.utils.rnn.pad_sequence(encoder_hidden_states, batch_first=True, padding_value=0.0)
    batch = {"labels": labels, "encoder_outputs": BaseModelOutput(last_hidden_state=encoder_hidden_states), **input_ids}
else:
    input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
    batch = {"labels": labels, **input_ids}
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
@@ -550,7 +565,40 @@ class DataCollatorStableSpeechWithPadding:
batch[self.feature_extractor_input_name] = input_values
return batch
@dataclass
class T5TextCollatorStableSpeechWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
description_tokenizer (:class:`~transformers.AutoTokenizer`)
The description_tokenizer used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
description_tokenizer: AutoTokenizer
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
input_ids_len = [len(feature["input_ids"]) for feature in features]
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
batch= {"len_input_ids": torch.tensor(input_ids_len), **input_ids}
return batch
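A quick usage sketch for this collator (the tokenizer checkpoint and texts are illustrative): it pads a batch of tokenized descriptions and records each example's true length, so the padded encoder hidden states can be trimmed back afterwards.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
feats = [{"input_ids": tok(text)["input_ids"]} for text in ["short", "a much longer description"]]

collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer=tok)
batch = collator(feats)
# batch["input_ids"] has shape (2, max_len); batch["len_input_ids"] holds the unpadded lengths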
def convert_dataset_str_to_list(
dataset_names,
@@ -729,6 +777,7 @@ def main():
else:
mixed_precision = "no"
####### A. Preparation
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
@@ -926,6 +975,9 @@ def main():
trust_remote_code=data_args.trust_remote_code,
)
# enable gradient checkpointing if necessary
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# 4. Now we preprocess the datasets including loading the audio, resampling and normalization
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
@@ -947,6 +999,9 @@ def main():
num_codebooks = model.decoder.config.num_codebooks
bandwidth = model_args.bandwidth
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
if not dataset_was_precomputed:
# resample target audio
raw_datasets = raw_datasets.cast_column(
@@ -996,6 +1051,8 @@ def main():
# 5. Now we encode the audio labels with encodec.
# We use Accelerate to perform distributed inference
####### B. Encode audio
logger.info("*** Encode target audio with encodec ***")
# no need to prepare audio_decoder because it is only used for inference without mixed precision
@@ -1090,6 +1147,52 @@ def main():
accelerator.free_memory()
del generate_labels
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
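Context for the AutocastKwargs guard: T5-like models overflow in fp16, so autocast has to be disabled whenever the run uses fp16 and the text encoder is executed. A minimal sketch of the pattern (assuming a CUDA device; accelerate's AutocastKwargs overrides the accelerator's default autocast behaviour):

from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

accelerator = Accelerator(mixed_precision="fp16")

# Force full precision for a specific region even under fp16 training.
no_fp16 = AutocastKwargs(enabled=False)
with accelerator.autocast(autocast_handler=no_fp16):
    pass  # run the text encoder here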
####### C. Encode text if the text encoder is frozen
if model_args.freeze_text_encoder:
text_data_collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of)
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
batch_size=training_args.text_encode_per_device_eval_batch_size,
collate_fn=text_data_collator,
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
)
data_loader = accelerator.prepare(data_loader)
all_encoder_outputs = []
all_encoder_lengths = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
model.text_encoder.to(batch["input_ids"].device)
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
# TODO: check that this works on multiple devices
all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
all_encoder_lengths.extend(lengths.to("cpu"))
def postprocess_dataset(input_ids, idx):
output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][:all_encoder_lengths[idx]])}
return output
# TODO: this is done multiple times; find a way to avoid recomputing it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=1,  # this one is resource-consuming with many processes.
input_columns=["input_ids"],
desc="Postprocessing labeling",
with_indices=True,
writer_batch_size=100,
)
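The length bookkeeping matters because the gathered hidden states are padded to the longest sequence across processes: each row is trimmed back to its true length before being stored, and the training collator later re-pads only to the batch maximum. A standalone sketch of that round trip (shapes are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

gathered = torch.randn(2, 10, 64)  # (batch, padded_len, hidden) after gather
lengths = torch.tensor([4, 7])     # true lengths recorded by the text collator

# Trim each example back to its true length before storing it row by row...
trimmed = [h[:l] for h, l in zip(gathered, lengths)]

# ...so the training collator can re-pad only to the batch maximum.
repadded = pad_sequence(trimmed, batch_first=True, padding_value=0.0)
assert repadded.shape == (2, 7, 64)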
if data_args.save_to_disk is not None and not dataset_was_precomputed:
@@ -1110,11 +1213,6 @@ def main():
# 6. Next, we can prepare the training.
# Let's use CLAP similarity and WER as our evaluation metrics,
# Define evaluation metrics during training, *i.e.* CLAP similarity. TODO: allow using another CLAP
@@ -1126,14 +1224,15 @@ def main():
def clap_similarity(texts, audios, device):
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
clap.to(device)
with torch.no_grad():
    text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
    audio_features = clap.get_audio_features(clap_inputs["input_features"])
    cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
clap.to("cpu")
clap_inputs.to("cpu")
return cosine_sim.mean().to("cpu")
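For context on the new torch.no_grad wrapper: CLAP only scores the generations, so building an autograd graph during feature extraction would waste memory. A minimal illustration:

import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)

y = x @ w                  # tracked: activations are kept for a backward pass
with torch.no_grad():
    z = x @ w              # untracked: nothing is stored for backward
print(y.requires_grad, z.requires_grad)  # True False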
def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
@@ -1186,6 +1285,9 @@ def main():
else:
eval_steps = training_args.eval_steps
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
# Define optimizer, LR scheduler, collator
optimizer = torch.optim.AdamW(
params=model.parameters(),
@@ -1208,8 +1310,6 @@ def main():
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of
)
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
@@ -1318,10 +1418,13 @@ def main():
):
model.train()
if mixed_precision == "fp16":
if mixed_precision == "fp16" and not model_args.freeze_text_encoder:
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
if training_args.parallel_mode.value != "distributed":
encoder_outputs = model.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
else:
encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
batch["encoder_outputs"] = encoder_outputs
outputs = model(**batch)
@@ -1335,10 +1438,13 @@ def main():
# Define eval fn
def eval_step(batch, accelerator, autocast_kwargs,):
model.eval()
if mixed_precision == "fp16":
if mixed_precision == "fp16" and not model_args.freeze_text_encoder:
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
if training_args.parallel_mode.value != "distributed":
encoder_outputs = model.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
else:
encoder_outputs = model.module.text_encoder(input_ids= batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
batch["encoder_outputs"] = encoder_outputs
with torch.no_grad():
@@ -1354,7 +1460,6 @@ def main():
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO: add args
@@ -1471,9 +1576,9 @@ def main():
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.pad_across_processes((generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0)
generated_audios, input_ids, prompts = accelerator.gather_for_metrics((generated_audios, input_ids, prompts))
eval_preds.extend(generated_audios.to("cpu"))
eval_descriptions.extend(input_ids.to("cpu"))
eval_prompts.extend(prompts.to("cpu"))
eval_time = time.time() - eval_start
# normalize eval metrics
@@ -1489,16 +1594,17 @@ def main():
)
eval_metrics.update(metric_values)
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
if "wandb" in training_args.report_to:
    log_pred(
        accelerator,
        pred_descriptions,
        pred_prompts,
        transcriptions,
        audios,
        sampling_rate=sampling_rate,
        step=cur_step,
        prefix="eval",
    )
# Print metrics and update progress bar
steps_trained_progress_bar.write(
@@ -2617,7 +2617,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id]
sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0)
if sample_mask.sum() > 0:
sample = sample[:, :, sample_mask]
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
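The fixed mask keeps only timesteps whose codes are all valid codebook entries: BOS, EOS and pad ids live at or above codebook_size, so any frame containing one is dropped before DAC decoding. A 2-D simplification (codebooks x timesteps) of the 3-D masking above, with a toy codebook size:

import torch

codebook_size = 1024  # toy value; the real one comes from the audio encoder config
sample = torch.tensor([[  5, 900, 1024],
                       [300,  12, 1025]])  # (num_codebooks, seq_len)

# A timestep survives only if no codebook emitted a special (>= codebook_size) id.
sample_mask = (sample >= codebook_size).sum(dim=0) == 0
print(sample_mask)             # tensor([ True,  True, False])
print(sample[:, sample_mask])  # the frame with special ids is removed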