Commit c2cd02ac (unverified) in chenpangpang/transformers
Authored Apr 30, 2021 by Takuya Makino; committed by GitHub on Apr 30, 2021

Accepts BatchEncoding in LengthSampler (#11431)

Parent: 30ede899
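The change in a nutshell: LengthGroupedSampler and DistributedLengthGroupedSampler infer per-example lengths from dataset items, but the old guard only accepted plain dicts. Tokenizers return BatchEncoding, which subclasses collections.UserDict rather than dict, so it failed the isinstance check even though it supports dict-style access. A minimal sketch of the failure mode (illustration only, not library code):

    from transformers.tokenization_utils_base import BatchEncoding

    item = BatchEncoding({"input_ids": [1, 2, 3]})
    isinstance(item, dict)   # False: UserDict is not a dict subclass
    "input_ids" in item      # True: dict-style membership still works
    len(item["input_ids"])   # 3: so the length is perfectly inferable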
Showing 2 changed files with 40 additions and 2 deletions.
src/transformers/trainer_pt_utils.py  (+9, -2)
tests/test_trainer_utils.py  (+31, -0)
src/transformers/trainer_pt_utils.py

@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
 from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
+from .tokenization_utils_base import BatchEncoding
 from .utils import logging

@@ -514,7 +515,10 @@ class LengthGroupedSampler(Sampler):
         self.batch_size = batch_size
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."

@@ -575,7 +579,10 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
 
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
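The two hunks above apply the same fix to both samplers: the guard now accepts any first item that is a dict or a BatchEncoding and still contains the model input key. A standalone sketch of the patched condition (the helper name length_inferable is mine, not the library's; isinstance with a tuple is equivalent to the diff's two separate calls):

    from transformers.tokenization_utils_base import BatchEncoding

    def length_inferable(first_item, model_input_name="input_ids"):
        # Mirrors the patched guard: dict or BatchEncoding, with the key present
        return isinstance(first_item, (dict, BatchEncoding)) and model_input_name in first_item

    length_inferable({"input_ids": [1, 2]})                 # True
    length_inferable(BatchEncoding({"input_ids": [1, 2]}))  # True
    length_inferable([1, 2])                                # False: the sampler raises ValueError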
tests/test_trainer_utils.py

@@ -27,6 +27,7 @@ if is_torch_available():
     from torch.utils.data import IterableDataset
 
     from transformers.modeling_outputs import SequenceClassifierOutput
+    from transformers.tokenization_utils_base import BatchEncoding
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
         DistributedSamplerWithLoop,

@@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase):
         # The indices should be a permutation of range(100)
         self.assertEqual(list(sorted(indices)), list(range(100)))
 
+    def test_group_by_length_with_dict(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append({"input_ids": input_ids})
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
+    def test_group_by_length_with_batch_encoding(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append(BatchEncoding({"input_ids": input_ids}))
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
     def test_distributed_length_grouped(self):
         # Get some inputs of random lengths
         lengths = torch.randint(0, 25, (100,)).tolist()
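The two new tests mirror each other: one feeds plain dicts, the other wraps the same data in BatchEncoding, and both expect the 105-token outlier to be placed first. A hedged end-to-end sketch of the now-supported usage, matching the positional (dataset, batch_size) call the tests make:

    import torch
    from transformers.tokenization_utils_base import BatchEncoding
    from transformers.trainer_pt_utils import LengthGroupedSampler

    # Three 100-token examples plus one 105-token outlier, as BatchEncoding items
    data = [BatchEncoding({"input_ids": torch.randint(0, 25, (n,)).tolist()}) for n in (100, 100, 105, 100)]
    sampler = LengthGroupedSampler(data, 2)  # lengths inferred from the "input_ids" key
    print(list(sampler))  # per the tests, index 2 (the longest example) sorts first

To run just these tests in a transformers checkout (assuming pytest is installed):

    pytest tests/test_trainer_utils.py -k "group_by_length" -v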