Unverified Commit 699541c4 authored by Setu Shah's avatar Setu Shah Committed by GitHub
Browse files

TFTrainer: Add dataloader_drop_last (#4925)

parent e80d6c68
...@@ -286,6 +286,7 @@ class Trainer: ...@@ -286,6 +286,7 @@ class Trainer:
sampler=sampler, sampler=sampler,
batch_size=self.args.eval_batch_size, batch_size=self.args.eval_batch_size,
collate_fn=self.data_collator.collate_batch, collate_fn=self.data_collator.collate_batch,
drop_last=self.args.dataloader_drop_last,
) )
return data_loader return data_loader
......
...@@ -68,7 +68,7 @@ class TFTrainer: ...@@ -68,7 +68,7 @@ class TFTrainer:
ds = ( ds = (
self.train_dataset.cache() self.train_dataset.cache()
.shuffle(self.num_train_examples) .shuffle(self.num_train_examples)
.batch(self.args.train_batch_size) .batch(self.args.train_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE) .prefetch(tf.data.experimental.AUTOTUNE)
) )
...@@ -82,12 +82,16 @@ class TFTrainer: ...@@ -82,12 +82,16 @@ class TFTrainer:
raise ValueError("Trainer: evaluation requires an eval_dataset.") raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
ds = eval_dataset.cache().batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE) ds = (
eval_dataset.cache()
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
.prefetch(tf.data.experimental.AUTOTUNE)
)
return self.args.strategy.experimental_distribute_dataset(ds) return self.args.strategy.experimental_distribute_dataset(ds)
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset: def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
ds = test_dataset.batch(self.args.eval_batch_size) ds = test_dataset.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
return self.args.strategy.experimental_distribute_dataset(ds) return self.args.strategy.experimental_distribute_dataset(ds)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment