Commit 0e1869cc authored by Setu Shah's avatar Setu Shah Committed by Julien Chaumond
Browse files

Add drop_last arg for data loader

parent 48a05026
...@@ -240,6 +240,7 @@ class Trainer: ...@@ -240,6 +240,7 @@ class Trainer:
batch_size=self.args.train_batch_size, batch_size=self.args.train_batch_size,
sampler=train_sampler, sampler=train_sampler,
collate_fn=self.data_collator.collate_batch, collate_fn=self.data_collator.collate_batch,
drop_last=self.args.dataloader_drop_last,
) )
return data_loader return data_loader
...@@ -264,6 +265,7 @@ class Trainer: ...@@ -264,6 +265,7 @@ class Trainer:
sampler=sampler, sampler=sampler,
batch_size=self.args.eval_batch_size, batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch, collate_fn=self.data_collator.collate_batch,
drop_last=self.args.dataloader_drop_last,
) )
return data_loader return data_loader
......
...@@ -133,6 +133,10 @@ class TrainingArguments: ...@@ -133,6 +133,10 @@ class TrainingArguments:
) )
tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"}) tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
if self.per_gpu_train_batch_size: if self.per_gpu_train_batch_size:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment