Commit d112db94 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

add group by length sampler

parent 6f5a0277
......@@ -55,6 +55,7 @@ from transformers import (
Seq2SeqTrainingArguments,
)
from transformers.trainer_utils import is_main_process
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry
......@@ -1200,7 +1201,8 @@ def main():
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
sampler = LengthGroupedSampler(per_device_train_batch_size, lengths = vectorized_datasets["train"]["target_length"])
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
......@@ -1343,6 +1345,7 @@ def main():
vectorized_datasets["train"],
collate_fn=data_collator,
batch_size=per_device_train_batch_size,
sampler=sampler,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment