Unverified Commit 9232a47b authored by Yoach Lacombe, committed by GitHub

Merge pull request #53 from ylacombe/nits-improvements

[Training] Small nits
parents 5518cc2f a0bc9e78
......@@ -218,7 +218,7 @@ class DataTrainingArguments:
metadata={
"help": (
"If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
"Also, used to set maximum desription token length if `pad_to_max_length=True`."
"Also, used to set maximum description token length if `pad_to_max_length=True`."
)
},
)
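Not part of this diff, but for context: a minimal sketch of how a `max_description_token_length` filter like this one is usually applied with the `datasets` library. The `vectorized_datasets`, `num_workers` and "input_ids" names below are assumptions about the surrounding script, not something this hunk shows.

    # Sketch only: drop samples whose tokenized description exceeds the limit.
    def is_description_in_length_range(input_ids):
        return len(input_ids) <= data_args.max_description_token_length

    vectorized_datasets = vectorized_datasets.filter(
        is_description_in_length_range,
        num_proc=num_workers,
        input_columns=["input_ids"],
    )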
......@@ -277,6 +277,12 @@ class DataTrainingArguments:
default="parler-speech",
metadata={"help": "The name of the wandb project."},
)
wandb_run_name: str = field(
default=None,
metadata={
"help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
},
)
save_to_disk: str = field(
default=None,
metadata={
......
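As an aside, the new field follows the usual `HfArgumentParser` dataclass pattern, so `--wandb_run_name` becomes available on the command line. A trimmed-down, illustrative sketch (not the full class from the script):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class DataTrainingArguments:
        wandb_project: str = field(
            default="parler-speech",
            metadata={"help": "The name of the wandb project."},
        )
        wandb_run_name: Optional[str] = field(
            default=None,
            metadata={"help": "If specified, the name of the run; otherwise wandb picks a random one."},
        )

    # e.g. python run_training.py --wandb_run_name my-first-run
    (data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses()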
......@@ -31,7 +31,7 @@ class DataCollatorEncodecWithPadding:
audios = [feature[self.audio_column_name]["array"] for feature in features]
len_audio = [len(audio) for audio in audios]
# since resampling has already been performed in the 'load_multiple_datasets' function,
# a fixed sampling_rate(44100hz) is passed to the feature_extractor.
sampling_rate = self.feature_extractor.sampling_rate
batch = self.feature_extractor(
......
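The comment refers to work already done upstream: `load_multiple_datasets` resamples everything once, so the collator can hand the extractor a fixed rate. A minimal sketch of that kind of up-front resampling with `datasets` (the dataset id and the "audio" column name are placeholders, and whether `load_multiple_datasets` does exactly this is an assumption):

    from datasets import Audio, load_dataset

    # Sketch only: resample once, up front, to the feature extractor's rate (44100 Hz here),
    # so the collator can pass a fixed sampling_rate instead of resampling per batch.
    dataset = load_dataset("some/audio-dataset", split="train")  # placeholder dataset id
    dataset = dataset.cast_column("audio", Audio(sampling_rate=44100))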
......@@ -98,9 +98,6 @@ def main():
####### A. Preparation
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
if training_args.torch_compile:
# TODO(YL): add more compile modes?
kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
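The removed lines appended a `TorchDynamoPlugin` to `kwargs_handlers`. If compile settings still need to go through accelerate, one option is the dedicated `dynamo_backend` argument on `Accelerator`; the sketch below only illustrates that option, assuming a reasonably recent accelerate version, and is not something this diff adds:

    from datetime import timedelta

    from accelerate import Accelerator
    from accelerate.utils import InitProcessGroupKwargs

    kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        kwargs_handlers=kwargs_handlers,
        # torch.compile configuration has its own argument; "inductor" mirrors the removed default.
        dynamo_backend="inductor" if training_args.torch_compile else "no",
    )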
......@@ -129,6 +126,7 @@ def main():
"adam_beta2": training_args.adam_beta2,
"temperature": model_args.temperature,
},
init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else None,
)
# Detecting the last checkpoint and possibly resuming from it
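`init_kwargs` is forwarded per tracker by accelerate, so the run name only reaches wandb when that tracker is active. A minimal sketch of the intended effect, with placeholder config values and an empty-dict fallback (the fallback is an assumption, not a line from this diff):

    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")
    run_name = data_args.wandb_run_name  # may be None

    accelerator.init_trackers(
        project_name=data_args.wandb_project,
        config={"learning_rate": training_args.learning_rate},  # placeholder subset of the real config
        # Only name the run when one was given; otherwise wandb picks a random name.
        init_kwargs={"wandb": {"name": run_name}} if run_name else {},
    )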
......@@ -538,7 +536,7 @@ def main():
logger.info(f"Dataset saved at {data_args.save_to_disk}")
audio_max_length = None
if training_args.torch_compile:
if padding == "max_length":
audio_max_length = max(vectorized_datasets["train"]["target_length"])
with accelerator.main_process_first():
max_sample = vectorized_datasets["train"].filter(
......@@ -548,6 +546,18 @@ def main():
)
audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
if training_args.group_by_length:
# apply a simple heuristic to take into account audio and text lengths
def add_target_lengths(target_length, prompt, description):
return {"target_length": target_length + len(prompt) + len(description)}
with accelerator.main_process_first():
vectorized_datasets = vectorized_datasets.map(
add_target_lengths,
num_proc=num_workers,
input_columns=["target_length", "prompt_input_ids", "input_ids"],
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will most likely
# be a timeout when running the script in distributed mode.
......
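The `add_target_lengths` heuristic simply adds the prompt and description token counts to each sample's audio target length, so that length-grouped batching sorts on one combined number. A hedged sketch of how such a column is typically consumed, using transformers' `LengthGroupedSampler` (whether this script wires it up exactly like this is an assumption):

    from torch.utils.data import DataLoader
    from transformers.trainer_pt_utils import LengthGroupedSampler

    # Sketch only: combined lengths as produced by `add_target_lengths` above.
    lengths = vectorized_datasets["train"]["target_length"]

    sampler = LengthGroupedSampler(
        batch_size=training_args.per_device_train_batch_size,
        lengths=lengths,
    )
    train_dataloader = DataLoader(
        vectorized_datasets["train"],
        batch_size=training_args.per_device_train_batch_size,
        sampler=sampler,
        collate_fn=data_collator,  # assumed collator instance from earlier in the script
    )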