"include/vscode:/vscode.git/clone" did not exist on "63fd5da63789ac59d9f4ebeefc38ba8397bc8a27"
Unverified commit 9232a47b, authored by Yoach Lacombe, committed by GitHub

Merge pull request #53 from ylacombe/nits-improvements

[Training] Small nits
parents 5518cc2f a0bc9e78
@@ -218,7 +218,7 @@ class DataTrainingArguments:
         metadata={
             "help": (
                 "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
-                "Also, used to set maximum desription token length if `pad_to_max_length=True`."
+                "Also, used to set maximum description token length if `pad_to_max_length=True`."
             )
         },
     )
@@ -277,6 +277,12 @@ class DataTrainingArguments:
         default="parler-speech",
         metadata={"help": "The name of the wandb project."},
     )
+    wandb_run_name: str = field(
+        default=None,
+        metadata={
+            "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
+        },
+    )
     save_to_disk: str = field(
         default=None,
         metadata={
...
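Note (not part of the diff): fields declared this way typically surface as command-line flags via `transformers.HfArgumentParser`, the usual pattern in HF training scripts. A minimal sketch of that assumed wiring, not code from this PR:

```python
# Minimal sketch, assuming the script parses its dataclasses with
# HfArgumentParser; illustrative only, not code from the PR.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    wandb_project: str = field(
        default="parler-speech",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_run_name: str = field(
        default=None,
        metadata={"help": "If specified, the name of the run."},
    )


parser = HfArgumentParser(DataTrainingArguments)
# equivalent to launching the script with `--wandb_run_name my-experiment`
(data_args,) = parser.parse_args_into_dataclasses(["--wandb_run_name", "my-experiment"])
print(data_args.wandb_run_name)  # my-experiment
```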
@@ -31,7 +31,7 @@ class DataCollatorEncodecWithPadding:
         audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]
         # since resampling has already been performed in the 'load_multiple_datasets' function,
         # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
         sampling_rate = self.feature_extractor.sampling_rate
         batch = self.feature_extractor(
...
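For context, a minimal sketch of the collation step quoted above, using `transformers.EncodecFeatureExtractor` as a stand-in; the extractor class and the audio arrays here are assumptions, not taken from this PR:

```python
# Illustrative sketch of the collator's core step; the audio data is synthetic.
import numpy as np
from transformers import EncodecFeatureExtractor

feature_extractor = EncodecFeatureExtractor(sampling_rate=44100)
audios = [np.zeros(44100, dtype=np.float32), np.zeros(22050, dtype=np.float32)]
len_audio = [len(audio) for audio in audios]

# the rate comes from the feature extractor itself, matching the resampling
# already done upstream, rather than a hard-coded 44100
sampling_rate = feature_extractor.sampling_rate
batch = feature_extractor(audios, sampling_rate=sampling_rate, padding=True, return_tensors="pt")
print(batch["input_values"].shape)  # (2, 1, 44100): padded to the longest clip
```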
@@ -98,9 +98,6 @@ def main():
     ####### A. Preparation
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
-    if training_args.torch_compile:
-        # TODO(YL): add more compile modes?
-        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))  # reduce-overhead
     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
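The removed block had routed `torch_compile` through accelerate's `TorchDynamoPlugin`. Purely as a reference for the backend/mode pair it named, and not code from this PR, the same settings can be exercised with plain `torch.compile`:

```python
# Stand-alone sketch of the compile settings the removed code referenced;
# assumes PyTorch 2.x with the inductor backend available.
import torch

model = torch.nn.Linear(8, 8)
compiled_model = torch.compile(model, backend="inductor", mode="default")
out = compiled_model(torch.randn(2, 8))  # first call triggers compilation
```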
@@ -129,6 +126,7 @@ def main():
             "adam_beta2": training_args.adam_beta2,
             "temperature": model_args.temperature,
         },
+        init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else None,
     )
     # Detecting last checkpoint and eventually continue from last checkpoint
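`init_kwargs` is accelerate's standard hook for forwarding per-tracker options through `Accelerator.init_trackers`. A minimal sketch with placeholder values (project name and config are stand-ins, not from this PR):

```python
# Minimal sketch of how the run name reaches wandb via accelerate;
# the project name and config values are placeholders.
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
run_name = "my-experiment"  # stands in for data_args.wandb_run_name
accelerator.init_trackers(
    project_name="parler-speech",
    config={"learning_rate": 1e-4},
    # keyed by tracker name; forwarded to wandb.init(...)
    init_kwargs={"wandb": {"name": run_name}} if run_name else {},
)
```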
@@ -538,7 +536,7 @@ def main():
         logger.info(f"Dataset saved at {data_args.save_to_disk}")
     audio_max_length = None
-    if training_args.torch_compile:
+    if padding == "max_length":
         audio_max_length = max(vectorized_datasets["train"]["target_length"])
         with accelerator.main_process_first():
             max_sample = vectorized_datasets["train"].filter(
@@ -548,6 +546,18 @@ def main():
             )
             audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
+    if training_args.group_by_length:
+        # apply a simple heuristic to take into account audio and text lengths
+        def add_target_lengths(target_length, prompt, description):
+            return {"target_length": target_length + len(prompt) + len(description)}
+
+        with accelerator.main_process_first():
+            vectorized_datasets = vectorized_datasets.map(
+                add_target_lengths,
+                num_proc=num_workers,
+                input_columns=["target_length", "prompt_input_ids", "input_ids"],
+            )
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
     # be a timeout when running the script in distributed mode.
...
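The `group_by_length` heuristic added above folds text token counts into `target_length`, so a length-grouped sampler sees one combined figure per sample. A self-contained sketch of the same `datasets.map` call on made-up data:

```python
# Self-contained sketch of the group_by_length heuristic; column values
# are invented for illustration.
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "target_length": [1200, 800],              # audio codec frames
        "prompt_input_ids": [[1] * 35, [1] * 20],  # tokenized transcript
        "input_ids": [[1] * 52, [1] * 40],         # tokenized description
    }
)


def add_target_lengths(target_length, prompt, description):
    # combined proxy: audio frames plus both text token counts
    return {"target_length": target_length + len(prompt) + len(description)}


# input_columns passes the named columns as positional arguments
ds = ds.map(add_target_lengths, input_columns=["target_length", "prompt_input_ids", "input_ids"])
print(ds["target_length"])  # [1287, 860]
```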