Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
63cc5bda
Unverified
Commit
63cc5bda
authored
Sep 29, 2021
by
Sylvain Gugger
Committed by
GitHub
Sep 29, 2021
Browse files
Fix length of IterableDatasetShard and add test (#13792)
* Fix length of IterableDatasetShard and add test * Add comments
parent
7d84c3a4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
2 deletions
+31
-2
src/transformers/trainer_pt_utils.py
src/transformers/trainer_pt_utils.py
+2
-2
tests/test_trainer_utils.py
tests/test_trainer_utils.py
+28
-0
utils/tests_fetcher.py
utils/tests_fetcher.py
+1
-0
No files found.
src/transformers/trainer_pt_utils.py
View file @
63cc5bda
...
@@ -775,9 +775,9 @@ class IterableDatasetShard(IterableDataset):
...
@@ -775,9 +775,9 @@ class IterableDatasetShard(IterableDataset):
def
__len__
(
self
):
def
__len__
(
self
):
# Will raise an error if the underlying dataset is not sized.
# Will raise an error if the underlying dataset is not sized.
if
self
.
drop_last
:
if
self
.
drop_last
:
return
len
(
self
.
dataset
)
//
self
.
num_processes
return
(
len
(
self
.
dataset
)
//
(
self
.
batch_size
*
self
.
num_processes
))
*
self
.
batch_size
else
:
else
:
return
math
.
ceil
(
len
(
self
.
dataset
)
/
self
.
num_processes
)
return
math
.
ceil
(
len
(
self
.
dataset
)
/
(
self
.
batch_size
*
self
.
num_processes
))
*
self
.
batch_size
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
...
...
tests/test_trainer_utils.py
View file @
63cc5bda
...
@@ -355,6 +355,34 @@ class TrainerUtilsTest(unittest.TestCase):
...
@@ -355,6 +355,34 @@ class TrainerUtilsTest(unittest.TestCase):
self
.
check_iterable_dataset_shard
(
dataset
,
4
,
drop_last
=
True
,
num_processes
=
3
,
epoch
=
42
)
self
.
check_iterable_dataset_shard
(
dataset
,
4
,
drop_last
=
True
,
num_processes
=
3
,
epoch
=
42
)
self
.
check_iterable_dataset_shard
(
dataset
,
4
,
drop_last
=
False
,
num_processes
=
3
,
epoch
=
42
)
self
.
check_iterable_dataset_shard
(
dataset
,
4
,
drop_last
=
False
,
num_processes
=
3
,
epoch
=
42
)
def
test_iterable_dataset_shard_with_length
(
self
):
sampler_shards
=
[
IterableDatasetShard
(
list
(
range
(
100
)),
batch_size
=
4
,
drop_last
=
True
,
num_processes
=
2
,
process_index
=
i
)
for
i
in
range
(
2
)
]
# Build expected shards: each process will have batches of size 4 until there is not enough elements to
# form two full batches (so we stop at 96 = (100 // (4 * 2)) * 4)
expected_shards
=
[[],
[]]
current_shard
=
0
for
i
in
range
(
0
,
96
,
4
):
expected_shards
[
current_shard
].
extend
(
list
(
range
(
i
,
i
+
4
)))
current_shard
=
1
-
current_shard
self
.
assertListEqual
([
list
(
shard
)
for
shard
in
sampler_shards
],
expected_shards
)
self
.
assertListEqual
([
len
(
shard
)
for
shard
in
sampler_shards
],
[
len
(
shard
)
for
shard
in
expected_shards
])
sampler_shards
=
[
IterableDatasetShard
(
list
(
range
(
100
)),
batch_size
=
4
,
drop_last
=
False
,
num_processes
=
2
,
process_index
=
i
)
for
i
in
range
(
2
)
]
# When drop_last=False, we get two last full batches by looping back to the beginning.
expected_shards
[
0
].
extend
(
list
(
range
(
96
,
100
)))
expected_shards
[
1
].
extend
(
list
(
range
(
0
,
4
)))
self
.
assertListEqual
([
list
(
shard
)
for
shard
in
sampler_shards
],
expected_shards
)
self
.
assertListEqual
([
len
(
shard
)
for
shard
in
sampler_shards
],
[
len
(
shard
)
for
shard
in
expected_shards
])
def
check_shard_sampler
(
self
,
dataset
,
batch_size
,
drop_last
,
num_processes
=
2
):
def
check_shard_sampler
(
self
,
dataset
,
batch_size
,
drop_last
,
num_processes
=
2
):
shards
=
[
shards
=
[
ShardSampler
(
ShardSampler
(
...
...
utils/tests_fetcher.py
View file @
63cc5bda
...
@@ -281,6 +281,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
...
@@ -281,6 +281,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
"test_trainer_distributed.py"
,
"test_trainer_distributed.py"
,
"test_trainer_tpu.py"
,
"test_trainer_tpu.py"
,
],
],
"train_pt_utils.py"
:
"test_trainer_utils.py"
,
"utils/versions.py"
:
"test_versions_utils.py"
,
"utils/versions.py"
:
"test_versions_utils.py"
,
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment