Unverified commit 6e1ee47b, authored Apr 15, 2021 by Sylvain Gugger, committed by GitHub on Apr 15, 2021
Support for set_epoch (#11258)
parent c3fcba32
Changes: 2 changed files with 18 additions and 4 deletions (+18, -4)

src/transformers/trainer.py (+9, -1)
src/transformers/trainer_pt_utils.py (+9, -3)
src/transformers/trainer.py

@@ -191,9 +191,15 @@ class Trainer:
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
             Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
             :func:`~transformers.DataCollatorWithPadding` otherwise.
-        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
             The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+
+            Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
+            training in a distributed fashion, your iterable dataset should either use an internal attribute
+            :obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
+            processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
+            :obj:`set_epoch()` method that internally sets the seed of the RNGs used.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
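The new docstring describes two ways to make a randomized iterable dataset behave correctly in distributed training. Here is a minimal sketch of the second option; the class name and data are hypothetical and nothing below comes from the commit itself:

# A minimal sketch (hypothetical, not from this commit): an iterable dataset
# implementing the `set_epoch()` protocol described in the docstring above.
import torch
from torch.utils.data import IterableDataset

class MyShuffledStream(IterableDataset):
    def __init__(self, data, seed=0):
        self.data = list(data)
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # The Trainer calls this at the start of each epoch, so the shuffle
        # below depends on the epoch but is identical on every process.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        for i in torch.randperm(len(self.data), generator=g).tolist():
            yield self.data[i]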
@@ -1095,6 +1101,8 @@ class Trainer:
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
+            elif isinstance(train_dataloader.dataset, IterableDatasetShard):
+                train_dataloader.dataset.set_epoch(epoch)
             if is_torch_tpu_available():
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
src/transformers/trainer_pt_utils.py

@@ -598,8 +598,8 @@ class IterableDatasetShard(IterableDataset):
     :obj:`dataset` to generate your random numbers and call the
     :meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
     seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
-    Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom logic.
+    Alternatively, you can also implement a :obj:`set_epoch()` method in your iterable dataset to deal with this.

     Args:
         dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
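For completeness, a minimal sketch of the first option this docstring mentions: a dataset that exposes a :obj:`generator` attribute for the shard to reseed to :obj:`seed + epoch` before each iteration. Again, the class and data here are illustrative, not part of this diff:

# A minimal sketch (hypothetical): a dataset exposing a `generator` attribute.
# Per the docstring above, IterableDatasetShard reseeds it to `seed + epoch`
# before iterating, keeping the random order identical on all processes.
import torch
from torch.utils.data import IterableDataset

class GeneratorBackedStream(IterableDataset):
    def __init__(self, data):
        self.data = list(data)
        self.generator = torch.Generator()  # reseeded externally each epoch

    def __iter__(self):
        order = torch.randperm(len(self.data), generator=self.generator)
        for i in order.tolist():
            yield self.data[i]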
@@ -637,9 +637,15 @@ class IterableDatasetShard(IterableDataset):
     def set_epoch(self, epoch):
         self.epoch = epoch
+        if hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)

     def __iter__(self):
-        if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
+        if (
+            not hasattr(self.dataset, "set_epoch")
+            and hasattr(self.dataset, "generator")
+            and isinstance(self.dataset.generator, torch.Generator)
+        ):
             self.dataset.generator.manual_seed(self.seed + self.epoch)
         real_batch_size = self.batch_size * self.num_processes
         process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
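Putting the two files together, a sketch of the intended usage. This is assumed, not shown in the diff: MyShuffledStream is the hypothetical dataset sketched earlier, and the IterableDatasetShard keyword arguments follow its documented signature:

# Assumed usage (not from the diff): wrap an iterable dataset in
# IterableDatasetShard and bump the epoch before each pass, as the Trainer
# loop in trainer.py now does.
from transformers.trainer_pt_utils import IterableDatasetShard

dataset = MyShuffledStream(range(100), seed=42)  # hypothetical class from the earlier sketch
shard = IterableDatasetShard(dataset, batch_size=8, num_processes=2, process_index=0, seed=42)

for epoch in range(3):
    # With this commit, set_epoch is forwarded to dataset.set_epoch(epoch);
    # the shard's generator reseeding is skipped because set_epoch exists.
    shard.set_epoch(epoch)
    for example in shard:
        pass  # training step would go here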