Unverified commit 17efc806, authored by charon____, committed by GitHub

IterableDatasetShard should use per device batch size instead of real batch size (#14714)

parent 2a56edb3
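
Background for this fix: in TrainingArguments, train_batch_size and eval_batch_size are the "real" batch sizes, i.e. the per-device values scaled by the number of GPUs, while the DataLoader and IterableDatasetShard in the sharded IterableDataset path operate per process and therefore need the per-device values. A minimal sketch of the relation (the output_dir and the 4-GPU machine mentioned in the comments are hypothetical, for illustration only):

from transformers import TrainingArguments

# Hypothetical values, for illustration only.
args = TrainingArguments(output_dir="tmp", per_device_train_batch_size=2)

# per_device_train_batch_size is what a single process should load per step.
print(args.per_device_train_batch_size)  # 2

# train_batch_size is the "real" batch size: the per-device size scaled by
# max(1, n_gpu). On a 4-GPU machine this prints 8; passing it to a
# per-process DataLoader would make every process load the whole global batch.
print(args.train_batch_size)
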
@@ -653,7 +653,7 @@ class Trainer:
             return DataLoader(
                 train_dataset,
-                batch_size=self.args.train_batch_size,
+                batch_size=self.args.per_device_train_batch_size,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
                 pin_memory=self.args.dataloader_pin_memory,
@@ -722,7 +722,7 @@ class Trainer:
         if self.args.world_size > 1:
             eval_dataset = IterableDatasetShard(
                 eval_dataset,
-                batch_size=self.args.eval_batch_size,
+                batch_size=self.args.per_device_eval_batch_size,
                 drop_last=self.args.dataloader_drop_last,
                 num_processes=self.args.world_size,
                 process_index=self.args.process_index,
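
Why the shard needs the per-device size: IterableDatasetShard hands each process its own slice out of every block of batch_size * num_processes consecutive samples. The toy sketch below (a hypothetical helper in plain Python, not the real class from transformers.trainer_pt_utils) illustrates the idea; if batch_size were the real batch size instead, each process would claim world_size times too many samples per block:

# Illustrative reimplementation of the sharding idea behind IterableDatasetShard.
# Each process takes its own batch_size-sized slice out of every block of
# batch_size * num_processes samples, so batch_size must be the per-device size.
def shard_iterable(samples, batch_size, num_processes, process_index):
    block = batch_size * num_processes
    buffer = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) == block:
            start = process_index * batch_size
            yield from buffer[start : start + batch_size]
            buffer = []

# With a per-device batch size of 2 and 2 processes:
# process 0 sees [0, 1, 4, 5], process 1 sees [2, 3, 6, 7].
print(list(shard_iterable(range(8), batch_size=2, num_processes=2, process_index=0)))
print(list(shard_iterable(range(8), batch_size=2, num_processes=2, process_index=1)))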