Unverified Commit 17efc806 authored by charon____, committed by GitHub

IterableDatasetShard should use per device batch size instead of real batch size (#14714)

parent 2a56edb3
@@ -653,7 +653,7 @@ class Trainer:
         return DataLoader(
             train_dataset,
-            batch_size=self.args.train_batch_size,
+            batch_size=self.args.per_device_train_batch_size,
             collate_fn=self.data_collator,
             num_workers=self.args.dataloader_num_workers,
             pin_memory=self.args.dataloader_pin_memory,
@@ -722,7 +722,7 @@ class Trainer:
         if self.args.world_size > 1:
             eval_dataset = IterableDatasetShard(
                 eval_dataset,
                 batch_size=self.args.eval_batch_size,
+                batch_size=self.args.per_device_eval_batch_size,
                 drop_last=self.args.dataloader_drop_last,
                 num_processes=self.args.world_size,
                 process_index=self.args.process_index,
...
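Why the change matters: in `TrainingArguments`, `train_batch_size` is the per-device size multiplied by the number of devices a single process drives (`per_device_train_batch_size * max(1, n_gpu)`), and `eval_batch_size` is defined analogously. `IterableDatasetShard` and each per-process `DataLoader`, however, operate on a single device's slice of the stream, so whenever the two sizes differ, passing the combined ("real") size inflates every shard's batches. The sketch below is a simplified, hypothetical reimplementation of the sharding arithmetic — the helper `shard_indices` is not part of the library — that shows the effect:

```python
# Simplified sketch of the distribution scheme used by
# IterableDatasetShard (hypothetical helper, not the transformers
# implementation): the stream is cut into chunks of
# batch_size * num_processes samples, and each process keeps its own
# batch_size-sized slice of every chunk.
def shard_indices(num_samples, batch_size, num_processes, process_index):
    chunk = batch_size * num_processes
    for start in range(0, num_samples - chunk + 1, chunk):
        first = start + process_index * batch_size
        yield from range(first, first + batch_size)

# per_device batch size of 4 on 2 processes: the shards are disjoint
# and each step delivers 4 samples per device, 8 in total.
print(list(shard_indices(16, batch_size=4, num_processes=2, process_index=0)))
# [0, 1, 2, 3, 8, 9, 10, 11]
print(list(shard_indices(16, batch_size=4, num_processes=2, process_index=1)))
# [4, 5, 6, 7, 12, 13, 14, 15]

# Passing the "real" batch size of 8 instead makes every process
# request chunks of 8 * 2 = 16 samples and emit batches twice as
# large as intended.
print(list(shard_indices(16, batch_size=8, num_processes=2, process_index=0)))
# [0, 1, 2, 3, 4, 5, 6, 7]
```

With the fix, the shard and the per-process `DataLoader` agree on the per-device size, so the effective global batch stays at `per_device_batch_size * world_size`, as configured.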