Commit faef1c72 authored by Yoach Lacombe
Browse files

fix sampling + free gpu memory after eval

parent 03611f97
......@@ -1202,11 +1202,9 @@ def main():
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
sampler = LengthGroupedSampler(per_device_train_batch_size, lengths = vectorized_datasets["train"]["target_length"])
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
logger.info(" Instantaneous batch size per device =" f" {per_device_train_batch_size}")
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
logger.info(
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
......@@ -1341,6 +1339,8 @@ def main():
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO: add args
sampler = LengthGroupedSampler(train_batch_size, lengths = vectorized_datasets["train"]["target_length"])
train_dataloader = DataLoader(
vectorized_datasets["train"],
collate_fn=data_collator,
......@@ -1450,7 +1450,7 @@ def main():
# TODO: also add prompt ids
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.pad_across_processes((generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0)
generated_audios, input_ids, prompts =accelerator.gather_for_metrics((generated_audios, input_ids, prompts))
generated_audios, input_ids, prompts = accelerator.gather_for_metrics((generated_audios, input_ids, prompts))
eval_preds.extend(generated_audios)
eval_descriptions.extend(input_ids)
eval_prompts.extend(prompts)
......@@ -1494,6 +1494,14 @@ def main():
epoch=epoch,
prefix="eval",
)
# release eval batch and reset metrics
eval_metrics = []
eval_preds = []
eval_descriptions = []
eval_prompts = []
batch = release_memory(batch)
# flush the train metrics
train_start = time.time()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment