Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5f721ad6
Unverified
Commit
5f721ad6
authored
Jun 18, 2020
by
Sylvain Gugger
Committed by
GitHub
Jun 18, 2020
Browse files
Fix #5114 (#5122)
parent
a258982a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
4 deletions
+16
-4
src/transformers/data/data_collator.py
src/transformers/data/data_collator.py
+2
-2
tests/test_trainer.py
tests/test_trainer.py
+14
-2
No files found.
src/transformers/data/data_collator.py
View file @
5f721ad6
...
@@ -42,10 +42,10 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
...
@@ -42,10 +42,10 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# Special handling for labels.
# Special handling for labels.
# Ensure that tensor is created with the correct type
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
# (it should be automatically the case, but let's make sure of it.)
if
"label"
in
first
:
if
"label"
in
first
and
first
[
"label"
]
is
not
None
:
dtype
=
torch
.
long
if
type
(
first
[
"label"
])
is
int
else
torch
.
float
dtype
=
torch
.
long
if
type
(
first
[
"label"
])
is
int
else
torch
.
float
batch
[
"labels"
]
=
torch
.
tensor
([
f
[
"label"
]
for
f
in
features
],
dtype
=
dtype
)
batch
[
"labels"
]
=
torch
.
tensor
([
f
[
"label"
]
for
f
in
features
],
dtype
=
dtype
)
elif
"label_ids"
in
first
:
elif
"label_ids"
in
first
and
first
[
"label_ids"
]
is
not
None
:
if
isinstance
(
first
[
"label_ids"
],
torch
.
Tensor
):
if
isinstance
(
first
[
"label_ids"
],
torch
.
Tensor
):
batch
[
"labels"
]
=
torch
.
stack
([
f
[
"label_ids"
]
for
f
in
features
])
batch
[
"labels"
]
=
torch
.
stack
([
f
[
"label_ids"
]
for
f
in
features
])
else
:
else
:
...
...
tests/test_trainer.py
View file @
5f721ad6
...
@@ -25,7 +25,7 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
...
@@ -25,7 +25,7 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
@
require_torch
@
require_torch
class
DataCollatorIntegrationTest
(
unittest
.
TestCase
):
class
DataCollatorIntegrationTest
(
unittest
.
TestCase
):
def
test_default_with_dict
(
self
):
def
test_default_with_dict
(
self
):
features
=
[{
"label
s
"
:
i
,
"inputs"
:
[
0
,
1
,
2
,
3
,
4
,
5
]}
for
i
in
range
(
8
)]
features
=
[{
"label"
:
i
,
"inputs"
:
[
0
,
1
,
2
,
3
,
4
,
5
]}
for
i
in
range
(
8
)]
batch
=
default_data_collator
(
features
)
batch
=
default_data_collator
(
features
)
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
...
@@ -39,12 +39,24 @@ class DataCollatorIntegrationTest(unittest.TestCase):
...
@@ -39,12 +39,24 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
6
]))
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
6
]))
# Features can already be tensors
# Features can already be tensors
features
=
[{
"label
s
"
:
i
,
"inputs"
:
torch
.
randint
(
10
,
[
10
])}
for
i
in
range
(
8
)]
features
=
[{
"label"
:
i
,
"inputs"
:
torch
.
randint
(
10
,
[
10
])}
for
i
in
range
(
8
)]
batch
=
default_data_collator
(
features
)
batch
=
default_data_collator
(
features
)
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
10
]))
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
10
]))
def test_default_with_no_labels(self):
    """A None label/label_ids entry must be skipped: no "labels" key in the batch.

    Covers both label conventions the collator understands ("label" and
    "label_ids"); in each case the non-label features are still collated.
    """
    for label_key in ("label", "label_ids"):
        examples = [{label_key: None, "inputs": [0, 1, 2, 3, 4, 5]} for _ in range(8)]
        collated = default_data_collator(examples)
        # None labels are dropped entirely rather than tensorized.
        self.assertNotIn("labels", collated)
        # The remaining features are batched as usual: 8 examples x 6 inputs.
        self.assertEqual(collated["inputs"].shape, torch.Size([8, 6]))
def
test_default_classification
(
self
):
def
test_default_classification
(
self
):
MODEL_ID
=
"bert-base-cased-finetuned-mrpc"
MODEL_ID
=
"bert-base-cased-finetuned-mrpc"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_ID
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_ID
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment