Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5f721ad6
Unverified
Commit
5f721ad6
authored
Jun 18, 2020
by
Sylvain Gugger
Committed by
GitHub
Jun 18, 2020
Browse files
Fix #5114 (#5122)
parent
a258982a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
4 deletions
+16
-4
src/transformers/data/data_collator.py
src/transformers/data/data_collator.py
+2
-2
tests/test_trainer.py
tests/test_trainer.py
+14
-2
No files found.
src/transformers/data/data_collator.py
View file @
5f721ad6
...
@@ -42,10 +42,10 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
...
@@ -42,10 +42,10 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# Special handling for labels.
# Special handling for labels.
# Ensure that tensor is created with the correct type
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
# (it should be automatically the case, but let's make sure of it.)
if
"label"
in
first
:
if
"label"
in
first
and
first
[
"label"
]
is
not
None
:
dtype
=
torch
.
long
if
type
(
first
[
"label"
])
is
int
else
torch
.
float
dtype
=
torch
.
long
if
type
(
first
[
"label"
])
is
int
else
torch
.
float
batch
[
"labels"
]
=
torch
.
tensor
([
f
[
"label"
]
for
f
in
features
],
dtype
=
dtype
)
batch
[
"labels"
]
=
torch
.
tensor
([
f
[
"label"
]
for
f
in
features
],
dtype
=
dtype
)
elif
"label_ids"
in
first
:
elif
"label_ids"
in
first
and
first
[
"label_ids"
]
is
not
None
:
if
isinstance
(
first
[
"label_ids"
],
torch
.
Tensor
):
if
isinstance
(
first
[
"label_ids"
],
torch
.
Tensor
):
batch
[
"labels"
]
=
torch
.
stack
([
f
[
"label_ids"
]
for
f
in
features
])
batch
[
"labels"
]
=
torch
.
stack
([
f
[
"label_ids"
]
for
f
in
features
])
else
:
else
:
...
...
tests/test_trainer.py
View file @
5f721ad6
...
@@ -25,7 +25,7 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
...
@@ -25,7 +25,7 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
@
require_torch
@
require_torch
class
DataCollatorIntegrationTest
(
unittest
.
TestCase
):
class
DataCollatorIntegrationTest
(
unittest
.
TestCase
):
def
test_default_with_dict
(
self
):
def
test_default_with_dict
(
self
):
features
=
[{
"label
s
"
:
i
,
"inputs"
:
[
0
,
1
,
2
,
3
,
4
,
5
]}
for
i
in
range
(
8
)]
features
=
[{
"label"
:
i
,
"inputs"
:
[
0
,
1
,
2
,
3
,
4
,
5
]}
for
i
in
range
(
8
)]
batch
=
default_data_collator
(
features
)
batch
=
default_data_collator
(
features
)
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
...
@@ -39,12 +39,24 @@ class DataCollatorIntegrationTest(unittest.TestCase):
...
@@ -39,12 +39,24 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
6
]))
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
6
]))
# Features can already be tensors
# Features can already be tensors
features
=
[{
"label
s
"
:
i
,
"inputs"
:
torch
.
randint
(
10
,
[
10
])}
for
i
in
range
(
8
)]
features
=
[{
"label"
:
i
,
"inputs"
:
torch
.
randint
(
10
,
[
10
])}
for
i
in
range
(
8
)]
batch
=
default_data_collator
(
features
)
batch
=
default_data_collator
(
features
)
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertTrue
(
batch
[
"labels"
].
equal
(
torch
.
tensor
(
list
(
range
(
8
)))))
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
self
.
assertEqual
(
batch
[
"labels"
].
dtype
,
torch
.
long
)
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
10
]))
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
10
]))
def
test_default_with_no_labels
(
self
):
features
=
[{
"label"
:
None
,
"inputs"
:
[
0
,
1
,
2
,
3
,
4
,
5
]}
for
i
in
range
(
8
)]
batch
=
default_data_collator
(
features
)
self
.
assertTrue
(
"labels"
not
in
batch
)
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
6
]))
# With label_ids
features
=
[{
"label_ids"
:
None
,
"inputs"
:
[
0
,
1
,
2
,
3
,
4
,
5
]}
for
i
in
range
(
8
)]
batch
=
default_data_collator
(
features
)
self
.
assertTrue
(
"labels"
not
in
batch
)
self
.
assertEqual
(
batch
[
"inputs"
].
shape
,
torch
.
Size
([
8
,
6
]))
def
test_default_classification
(
self
):
def
test_default_classification
(
self
):
MODEL_ID
=
"bert-base-cased-finetuned-mrpc"
MODEL_ID
=
"bert-base-cased-finetuned-mrpc"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_ID
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_ID
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment