chenpangpang / transformers

Commit c2cd02ac (Unverified)
Authored Apr 30, 2021 by Takuya Makino; committed by GitHub, Apr 30, 2021

Accepts BatchEncoding in LengthSampler (#11431)

parent 30ede899
Showing 2 changed files with 40 additions and 2 deletions:

    src/transformers/trainer_pt_utils.py   +9  -2
    tests/test_trainer_utils.py            +31 -0
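As background for the diff below: BatchEncoding, the mapping returned by the library's tokenizers, subclasses collections.UserDict rather than dict, so the samplers' length-inference guard rejected it even though it supports dict-style lookup. A minimal sketch of the behavior, assuming an installed transformers:

    from transformers.tokenization_utils_base import BatchEncoding

    item = BatchEncoding({"input_ids": [3, 1, 4, 1, 5]})
    print(isinstance(item, dict))  # False: the old guard raised ValueError here
    print("input_ids" in item)     # True: dict-style access still works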
src/transformers/trainer_pt_utils.py
@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
 from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
+from .tokenization_utils_base import BatchEncoding
 from .utils import logging
@@ -514,7 +515,10 @@ class LengthGroupedSampler(Sampler):
         self.batch_size = batch_size
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
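A usage sketch mirroring the new tests below, assuming the positional signature at this commit (dataset first, then batch_size):

    import torch
    from transformers.tokenization_utils_base import BatchEncoding
    from transformers.trainer_pt_utils import LengthGroupedSampler

    data = [BatchEncoding({"input_ids": torch.randint(0, 25, (100,)).tolist()}) for _ in range(6)]
    data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()  # one longer item

    indices = list(LengthGroupedSampler(data, 4))
    assert len(data[indices[0]]["input_ids"]) == 105  # longest item sorts first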
@@ -575,7 +579,10 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
 
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
tests/test_trainer_utils.py
@@ -27,6 +27,7 @@ if is_torch_available():
     from torch.utils.data import IterableDataset
 
     from transformers.modeling_outputs import SequenceClassifierOutput
+    from transformers.tokenization_utils_base import BatchEncoding
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
         DistributedSamplerWithLoop,
@@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase):
         # The indices should be a permutation of range(100)
         self.assertEqual(list(sorted(indices)), list(range(100)))
 
+    def test_group_by_length_with_dict(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append({"input_ids": input_ids})
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
+    def test_group_by_length_with_batch_encoding(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append(BatchEncoding({"input_ids": input_ids}))
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
     def test_distributed_length_grouped(self):
         # Get some inputs of random lengths
         lengths = torch.randint(0, 25, (100,)).tolist()
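To exercise just the new cases, a keyword-filtered run such as "python -m pytest tests/test_trainer_utils.py -k group_by_length" should work, assuming a torch-enabled environment (the torch-dependent imports in this module are gated on is_torch_available()).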