"tests/vscode:/vscode.git/clone" did not exist on "03a3becc48f14a481b578c4d1c02273da9a1cc81"
Unverified Commit 6e1ee47b authored by Sylvain Gugger, committed by GitHub

Support for set_epoch (#11258)

parent c3fcba32
@@ -191,9 +191,15 @@ class Trainer:
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
             Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
             :func:`~transformers.DataCollatorWithPadding` otherwise.
-        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
             The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+            Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
+            training in a distributed fashion, your iterable dataset should either use an internal attribute
+            :obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
+            processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
+            :obj:`set_epoch()` method that internally sets the seed of the RNGs used.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
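The note added above describes two contracts an iterable dataset can satisfy to stay in sync across distributed processes. A minimal sketch of the :obj:`set_epoch()` variant (not part of this commit; the class name ``ShuffledIterableDataset`` and the toy data are assumptions):

```python
import torch
from torch.utils.data import IterableDataset


class ShuffledIterableDataset(IterableDataset):
    """Toy iterable dataset that reshuffles deterministically per epoch."""

    def __init__(self, data, seed=0):
        self.data = list(data)
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # The Trainer calls this before each epoch; recording the epoch here
        # keeps the shuffle identical on every process for a given epoch
        # while still varying from one epoch to the next.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        for i in torch.randperm(len(self.data), generator=g).tolist():
            yield self.data[i]
```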
@@ -1095,6 +1101,8 @@ class Trainer:
        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
+           elif isinstance(train_dataloader.dataset, IterableDatasetShard):
+               train_dataloader.dataset.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
......
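As a quick illustration of the behavior the new ``elif`` branch relies on, using the ``ShuffledIterableDataset`` sketch from above (so every name here is an assumption, not Trainer API): calling ``set_epoch`` before iterating changes the order across epochs but keeps it deterministic for a given epoch.

```python
ds = ShuffledIterableDataset(range(8), seed=42)

ds.set_epoch(0)
epoch0 = list(ds)
ds.set_epoch(1)
epoch1 = list(ds)

print(epoch0)  # one permutation of 0..7
print(epoch1)  # almost surely a different permutation

ds.set_epoch(0)
assert list(ds) == epoch0  # same epoch, same order on every process
```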
@@ -598,8 +598,8 @@ class IterableDatasetShard(IterableDataset):
        :obj:`dataset` to generate your random numbers and call the
        :meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
        seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
-       Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
-       logic.
+       Alternatively, you can also implement a :obj:`set_epoch()` method in your iterable dataset to deal with this.

        Args:
            dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
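For the :obj:`generator` contract described in this docstring, where the shard itself reseeds ``dataset.generator`` to ``seed + epoch``, a minimal sketch could look like the following; the class name ``GeneratorBackedDataset`` is an assumption, not an API in this commit:

```python
import torch
from torch.utils.data import IterableDataset


class GeneratorBackedDataset(IterableDataset):
    """Toy iterable dataset whose randomness flows through ``self.generator``."""

    def __init__(self, data):
        self.data = list(data)
        # IterableDatasetShard looks for exactly this attribute: a
        # torch.Generator it can reseed to ``seed + epoch`` before iterating.
        self.generator = torch.Generator()

    def __iter__(self):
        # All randomization goes through self.generator, so reseeding it
        # externally is enough to make the order reproducible per epoch.
        order = torch.randperm(len(self.data), generator=self.generator).tolist()
        for i in order:
            yield self.data[i]
```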
@@ -637,9 +637,15 @@ class IterableDatasetShard(IterableDataset):
    def set_epoch(self, epoch):
        self.epoch = epoch
+       if hasattr(self.dataset, "set_epoch"):
+           self.dataset.set_epoch(epoch)

    def __iter__(self):
-       if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
+       if (
+           not hasattr(self.dataset, "set_epoch")
+           and hasattr(self.dataset, "generator")
+           and isinstance(self.dataset.generator, torch.Generator)
+       ):
            self.dataset.generator.manual_seed(self.seed + self.epoch)
        real_batch_size = self.batch_size * self.num_processes
        process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
......
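Condensed into a standalone helper (a restatement for clarity, not code from this diff), the seeding rule these two hunks implement is: a dataset-provided ``set_epoch()`` takes precedence, and the shard only reseeds ``dataset.generator`` when no ``set_epoch()`` is available.

```python
import torch


def sync_epoch(dataset, seed, epoch):
    # Mirrors the order of checks in IterableDatasetShard: defer to the
    # dataset's own set_epoch() when it exists...
    if hasattr(dataset, "set_epoch"):
        dataset.set_epoch(epoch)
    # ...and only fall back to reseeding its generator otherwise, so the
    # two mechanisms never fight over the dataset's RNG state.
    elif hasattr(dataset, "generator") and isinstance(dataset.generator, torch.Generator):
        dataset.generator.manual_seed(seed + epoch)
```

Either ``ShuffledIterableDataset`` or ``GeneratorBackedDataset`` from the sketches above would be handled correctly by this rule.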