"cpp_onnx/vscode:/vscode.git/clone" did not exist on "46fc6fee4abf38bc3c17f45754d4f29c728accf6"
Commit 82cbc3ad authored by Yoach Lacombe

add torch compile compatibility + remove precompute_text_hidden_states

parent 1fe3fc1e
@@ -70,7 +70,7 @@ AutoModel.register(DACConfig, DACModel)
from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
@@ -247,17 +247,13 @@ class ModelArguments:
metadata={"help": "Temperature if sampling."},
)
max_length: int = field(
default=1500, # TODO
metadata={"help": "Whether to do sampling or greedy decoding."},
default=2580,
metadata={"help": "Generation max length."},
)
bandwidth: float = field(
default=6, # TODO
metadata={"help": "Audio encoder bandwidth."},
)
precompute_text_hidden_states: bool = field(
default=False,
metadata={"help": "Whether to precompute text hidden states. Only work when the text encoder is freezed"},
)
@@ -374,8 +370,8 @@ class DataTrainingArguments:
default=35.0,
metadata={
"help": (
"Filter audio files that are longer than `max_duration_in_seconds` seconds to"
" 'max_duration_in_seconds`"
"Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`."
"Also, used to set maximum audio length if `pad_to_max_length=True`."
)
},
)
@@ -383,7 +379,31 @@ class DataTrainingArguments:
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
)
max_text_length: int = field(
default=500, metadata={"help": "Max description lengths in number of characters."}
default=500, metadata={"help": "If set, max description lengths in number of characters."}
)
max_prompt_token_length: Optional[int] = field(
default=None, metadata={
"help": (
"If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
"Also, used to set maximum prompt token length if `pad_to_max_length=True`."
)
}
)
max_description_token_length: Optional[int] = field(
default=None, metadata={
"help": (
"If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
"Also, used to set maximum desription token length if `pad_to_max_length=True`."
)
}
)
pad_to_max_length: bool = field(
default=False, metadata={
"help": (
"If `True`, pad audio, prompt and description to a maximum length set with respectively "
"`max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`."
)
}
)
preprocessing_only: bool = field(
default=False,
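
Why this flag matters for the torch.compile work in this commit: with `padding="longest"` each batch gets its own sequence length and the compiled graph is retraced, while `padding="max_length"` keeps shapes static. A minimal sketch; the `t5-small` checkpoint and `max_length=64` are illustrative assumptions, not values from this script:

```python
# Sketch: padding="longest" gives batch-dependent shapes; padding="max_length"
# gives a fixed shape, so a torch.compile'd model is not retraced per batch.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
texts = ["a short prompt", "a noticeably longer prompt with many more tokens"]

dynamic = tokenizer(texts, padding="longest", return_tensors="pt")
static = tokenizer(texts, padding="max_length", max_length=64, return_tensors="pt")

print(dynamic["input_ids"].shape)  # second dim follows the longest sample
print(static["input_ids"].shape)   # always (2, 64), stable across batches
```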
@@ -490,13 +510,15 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
@dataclass
class DataCollatorEncodecWithPadding:
"""
Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
Data collator that will dynamically pad the inputs received to the longest sequence in the batch or
to `max_length` if `max_length` is set and `padding="max_length"`.
"""
feature_extractor: AutoFeatureExtractor
audio_column_name: str
feature_extractor_input_name: Optional[str] = "input_values"
max_length: Optional[int] = None
padding: Optional[str] = "longest"
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
@@ -504,10 +526,8 @@ class DataCollatorEncodecWithPadding:
# different padding methods
audios = [feature[self.audio_column_name]["array"] for feature in features]
len_audio = [len(audio) for audio in audios]
if self.max_length is not None:
audios = [audio[:min(len(audio), self.max_length + 10)] for audio in audios]
batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length)
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
return batch
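
A hedged usage sketch of `DataCollatorEncodecWithPadding` as defined above (it must be in scope); the `facebook/encodec_24khz` feature extractor and the one-second `max_length` are assumptions for illustration:

```python
# Sketch: with padding="max_length" every batch of waveforms comes out at a
# fixed number of samples, while len_audio preserves the true lengths.
import numpy as np
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_24khz")
collator = DataCollatorEncodecWithPadding(
    feature_extractor=feature_extractor,
    audio_column_name="audio",
    max_length=24000,       # 1 s at 24 kHz (assumed)
    padding="max_length",   # fixed shapes for torch.compile
)
features = [
    {"audio": {"array": np.random.randn(16000)}},
    {"audio": {"array": np.random.randn(24000)}},
]
batch = collator(features)
print(batch["input_values"].shape)  # (2, 1, 24000): both clips padded to max_length
print(batch["len_audio"])           # tensor([[16000], [24000]]): true lengths
```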
@@ -544,6 +564,9 @@ class DataCollatorStableSpeechWithPadding:
feature_extractor_input_name: Optional[str] = "input_values"
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
prompt_max_length: Optional[int] = None
description_max_length: Optional[int] = None
audio_max_length: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
@@ -553,28 +576,29 @@ class DataCollatorStableSpeechWithPadding:
labels = [torch.tensor(feature["labels"]).transpose(0,1) for feature in features]
# (bsz, seq_len, num_codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
if self.audio_max_length is not None and self.padding == "max_length":
labels = torch.nn.functional.pad(labels, pad=(0, 0, 0, max(self.audio_max_length - labels.shape[1], 0)))
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
if "encoder_outputs" in features[0]:
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding)
encoder_hidden_states = [torch.tensor(feature["encoder_outputs"]["last_hidden_state"]) for feature in features]
encoder_hidden_states = torch.nn.utils.rnn.pad_sequence(encoder_hidden_states,batch_first=True,padding_value=0.)
batch= {"labels":labels, "encoder_outputs": BaseModelOutput(last_hidden_state=encoder_hidden_states), **input_ids}
else:
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
batch= {"labels":labels, **input_ids}
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.description_max_length)
batch= {"labels":labels, **input_ids}
if self.audio_max_length is not None and self.padding == "max_length":
# when using torch.compile, we also need to specify the decoder attention_mask explicitly
decoder_attention_mask = torch.ones(labels.shape[:2], dtype=input_ids["attention_mask"].dtype)
batch["decoder_attention_mask"] = decoder_attention_mask
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, max_length=self.prompt_max_length)
batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
if "attention_mask" in prompt_input_ids:
batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
if self.feature_extractor_input_name in features[0]:
# TODO: verify that it works
# TODO (YL): verify that it works - IMPORTANT -> probably not working
input_values = [{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features]
input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
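
To make the new `audio_max_length` branch concrete, a toy sketch of the label padding it performs (9 codebooks and 8 frames are assumed values):

```python
# Sketch: variable-length codebook labels are first padded to the batch max
# with -100 (ignored by the loss), then to a fixed audio_max_length so that
# shapes stay static under torch.compile. Note F.pad fills with 0 by default.
import torch

audio_max_length = 8  # assumed fixed frame count
labels = [torch.zeros(5, 9, dtype=torch.long), torch.zeros(3, 9, dtype=torch.long)]
padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
padded = torch.nn.functional.pad(padded, pad=(0, 0, 0, max(audio_max_length - padded.shape[1], 0)))
decoder_attention_mask = torch.ones(padded.shape[:2], dtype=torch.long)
print(padded.shape)                  # (2, 8, 9): batch, fixed frames, codebooks
print(decoder_attention_mask.shape)  # (2, 8): the mask the compiled graph expects
```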
@@ -582,40 +606,6 @@ class DataCollatorStableSpeechWithPadding:
return batch
@dataclass
class T5TextCollatorStableSpeechWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
description_tokenizer (:class:`~transformers.AutoTokenizer`)
The description_tokenizer used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
description_tokenizer: AutoTokenizer
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
input_ids_len = [len(feature["input_ids"]) for feature in features]
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
batch= {"len_input_ids": torch.tensor(input_ids_len), **input_ids}
return batch
def convert_dataset_str_to_list(
dataset_names,
dataset_config_names,
@@ -821,15 +811,24 @@ def main():
mixed_precision = "bf16"
else:
mixed_precision = "no"
if data_args.pad_to_max_length and (data_args.max_duration_in_seconds is None or data_args.max_prompt_token_length is None or data_args.max_description_token_length is None):
raise ValueError("`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`")
padding = "max_length" if data_args.pad_to_max_length else "longest"
####### A. Preparation
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
if training_args.torch_compile:
# TODO(YL): add more compile modes?
kwargs_handlers.append(TorchDynamoPlugin(backend="inductor"))
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=training_args.report_to,
project_dir=training_args.output_dir,
kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(minutes=60))],
kwargs_handlers=kwargs_handlers,
)
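
A minimal sketch of the wiring above, assuming an `accelerate` version that accepts `TorchDynamoPlugin` in `kwargs_handlers` as this diff does; the linear layer is a toy stand-in for the real model:

```python
# Sketch: passing TorchDynamoPlugin makes accelerator.prepare() wrap the model
# with torch.compile using the chosen backend, here inductor as in the diff.
import torch
from accelerate import Accelerator
from accelerate.utils import TorchDynamoPlugin

plugin = TorchDynamoPlugin(backend="inductor")
accelerator = Accelerator(kwargs_handlers=[plugin])

model = torch.nn.Linear(8, 8)       # toy stand-in for the real model
model = accelerator.prepare(model)  # wrapped here; compiled lazily on first forward
out = model(torch.randn(2, 8))
print(out.shape)                    # (2, 8)
```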
accelerator.init_trackers(project_name=data_args.wandb_project, config={
@@ -1064,7 +1063,7 @@ def main():
if not dataset_was_precomputed:
# Filter on text length
if description_column_name is not None:
if description_column_name is not None and data_args.max_text_length is not None:
with accelerator.main_process_first():
# filter out descriptions longer than max_text_length
raw_datasets = raw_datasets.filter(
@@ -1098,64 +1097,19 @@ def main():
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
####### B. (Optional) Encode text if text encoder is frozen
if model_args.freeze_text_encoder and model_args.precompute_text_hidden_states:
text_data_collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of)
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
batch_size=training_args.text_encode_per_device_eval_batch_size,
collate_fn=text_data_collator,
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
)
data_loader = accelerator.prepare(data_loader)
all_encoder_outputs = []
all_encoder_lengths = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
model.text_encoder.to(batch["input_ids"].device)
with accelerator.autocast(autocast_handler=autocast_kwargs):
with torch.no_grad():
encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
if accelerator.is_main_process:
all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
all_encoder_lengths.extend(lengths.to("cpu"))
def postprocess_dataset(input_ids, idx):
output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][:all_encoder_lengths[idx]])}
return output
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=1, # this one is resource-consuming with many processes.
input_columns=["input_ids"],
desc="Postprocessing labeling",
with_indices=True,
writer_batch_size=100,
)
accelerator.wait_for_everyone()
accelerator.free_memory()
del data_loader, all_encoder_outputs, all_encoder_lengths
# Now we encode the audio labels with encodec.
####### C. Encode audio
####### B. Encode audio
logger.info("*** Encode target audio with encodec ***")
# no need to prepare audio_decoder because it is only used for inference without mixed precision
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
audio_decoder = model.audio_encoder
if training_args.torch_compile:
audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
else:
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length)
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length, padding=padding)
def apply_audio_decoder(batch):
len_audio = batch.pop("len_audio")
@@ -1251,6 +1205,10 @@ def main():
with accelerator.main_process_first():
# NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
# caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
# That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.
def is_audio_in_length_range(length):
return length > min_target_length and length < max_target_length
@@ -1260,11 +1218,41 @@ def main():
num_proc=num_workers,
input_columns=["target_length"],
)
if description_column_name is not None and data_args.max_description_token_length is not None:
with accelerator.main_process_first():
# filter out descriptions longer than max_description_token_length
vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x) < data_args.max_description_token_length,
num_proc=num_workers,
input_columns=["input_ids"],
)
if data_args.max_prompt_token_length is not None:
with accelerator.main_process_first():
# filter out prompts longer than max_prompt_token_length
vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x) < data_args.max_prompt_token_length,
num_proc=num_workers,
input_columns=["prompt_input_ids"],
)
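
A toy illustration of the filtering pattern used in the two blocks above; `input_columns` hands only the named column to the predicate, keeping the pass cheap:

```python
# Sketch: drop rows whose tokenized prompt exceeds the cap, mirroring the
# filters above on an in-memory toy dataset.
from datasets import Dataset

ds = Dataset.from_dict({"prompt_input_ids": [[1, 2, 3], list(range(600))]})
max_prompt_token_length = 512  # assumed cap
ds = ds.filter(
    lambda ids: len(ids) < max_prompt_token_length,
    input_columns=["prompt_input_ids"],
)
print(ds.num_rows)  # 1: the 600-token prompt was filtered out
```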
if data_args.save_to_disk is not None and not dataset_was_precomputed:
if accelerator.is_main_process:
vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"])-1))
logger.info(f"Dataset saved at {data_args.save_to_disk}")
audio_max_length = None
if training_args.torch_compile:
audio_max_length = max(vectorized_datasets["train"]["target_length"])
with accelerator.main_process_first():
max_sample = vectorized_datasets["train"].filter(
lambda x: x == audio_max_length,
num_proc=num_workers,
input_columns=["target_length"],
)
audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will most likely
@@ -1374,7 +1362,8 @@ def main():
# Instantiate custom data collator
data_collator = DataCollatorStableSpeechWithPadding(
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of,
padding=padding, prompt_max_length=data_args.max_prompt_token_length, description_max_length=data_args.max_description_token_length, audio_max_length=audio_max_length
)
@@ -1485,7 +1474,7 @@ def main():
):
model.train()
if mixed_precision == "fp16" and not (model_args.freeze_text_encoder and model_args.precompute_text_hidden_states):
if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed":
@@ -1505,7 +1494,7 @@ def main():
# Define eval fn
def eval_step(batch, accelerator, autocast_kwargs,):
model.eval()
if mixed_precision == "fp16" and not (model_args.freeze_text_encoder and model_args.precompute_text_hidden_states):
if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed":
@@ -1523,6 +1512,7 @@ def main():
def generate_step(batch):
model.eval()
batch.pop("decoder_attention_mask", None)
output_audios = accelerator.unwrap_model(model, keep_fp32_wrapper=(mixed_precision != "fp16")).generate(**batch, **gen_kwargs)
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios
@@ -1626,7 +1616,7 @@ def main():
for batch in tqdm(
validation_dataloader,
desc=f"Evaluating...",
desc=f"Evaluating - Inference ...",
position=2,
disable=not accelerator.is_local_main_process,
):
@@ -1635,8 +1625,23 @@ def main():
eval_metric = accelerator.gather_for_metrics(eval_metric)
eval_metrics.append(eval_metric)
if training_args.predict_with_generate:
validation_dataloader = DataLoader(
vectorized_datasets["eval"],
collate_fn=data_collator,
batch_size=per_device_eval_batch_size,
drop_last=False,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
validation_dataloader = accelerator.prepare(validation_dataloader)
# generation
if training_args.predict_with_generate:
for batch in tqdm(
validation_dataloader,
desc=f"Evaluating - Generation ...",
position=2,
disable=not accelerator.is_local_main_process,
):
generated_audios = generate_step(batch)
# Gather all predictions and targets
# TODO: also add prompt ids
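
Why generations are padded before gathering: each process may generate a different number of frames, and a distributed gather needs identical shapes. A single-process toy of the padding step that `accelerator.pad_across_processes` performs across ranks:

```python
# Sketch: right-pad the shorter rank's generations to the longest length so
# the tensors can be concatenated after an all-gather.
import torch

gen_rank0 = torch.ones(2, 5)  # 5 generated frames on one process
gen_rank1 = torch.ones(2, 8)  # 8 generated frames on another
target = max(gen_rank0.shape[1], gen_rank1.shape[1])
padded0 = torch.nn.functional.pad(gen_rank0, (0, target - gen_rank0.shape[1]), value=0)
print(padded0.shape)  # (2, 8): shapes now match across processes
```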
......