Unverified Commit 699541c4 authored by Setu Shah's avatar Setu Shah Committed by GitHub
Browse files

TFTrainer: Add dataloader_drop_last (#4925)

parent e80d6c68
......@@ -286,6 +286,7 @@ class Trainer:
sampler=sampler,
batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch,
drop_last=self.args.dataloader_drop_last,
)
return data_loader
......
......@@ -68,7 +68,7 @@ class TFTrainer:
ds = (
self.train_dataset.cache()
.shuffle(self.num_train_examples)
.batch(self.args.train_batch_size)
.batch(self.args.train_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
......@@ -82,12 +82,16 @@ class TFTrainer:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
ds = eval_dataset.cache().batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
ds = (
eval_dataset.cache()
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
return self.args.strategy.experimental_distribute_dataset(ds)
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
ds = test_dataset.batch(self.args.eval_batch_size)
ds = test_dataset.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
return self.args.strategy.experimental_distribute_dataset(ds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment