Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3a5d1ea2
Unverified
Commit
3a5d1ea2
authored
May 29, 2020
by
Zhangyx
Committed by
GitHub
May 29, 2020
Browse files
Fix two bugs: 1. Index of test data of SST-2. 2. Label index of MNLI data. (#4546)
parent
9c172564
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
10 deletions
+13
-10
src/transformers/data/datasets/glue.py
src/transformers/data/datasets/glue.py
+11
-9
src/transformers/data/processors/glue.py
src/transformers/data/processors/glue.py
+2
-1
No files found.
src/transformers/data/datasets/glue.py
View file @
3a5d1ea2
...
...
@@ -86,6 +86,15 @@ class GlueDataset(Dataset):
mode
.
value
,
tokenizer
.
__class__
.
__name__
,
str
(
args
.
max_seq_length
),
args
.
task_name
,
),
)
label_list
=
self
.
processor
.
get_labels
()
if
args
.
task_name
in
[
"mnli"
,
"mnli-mm"
]
and
tokenizer
.
__class__
in
(
RobertaTokenizer
,
RobertaTokenizerFast
,
XLMRobertaTokenizer
,
):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list
[
1
],
label_list
[
2
]
=
label_list
[
2
],
label_list
[
1
]
self
.
label_list
=
label_list
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
...
...
@@ -100,14 +109,7 @@ class GlueDataset(Dataset):
)
else
:
logger
.
info
(
f
"Creating features from dataset file at
{
args
.
data_dir
}
"
)
label_list
=
self
.
processor
.
get_labels
()
if
args
.
task_name
in
[
"mnli"
,
"mnli-mm"
]
and
tokenizer
.
__class__
in
(
RobertaTokenizer
,
RobertaTokenizerFast
,
XLMRobertaTokenizer
,
):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list
[
1
],
label_list
[
2
]
=
label_list
[
2
],
label_list
[
1
]
if
mode
==
Split
.
dev
:
examples
=
self
.
processor
.
get_dev_examples
(
args
.
data_dir
)
elif
mode
==
Split
.
test
:
...
...
@@ -137,4 +139,4 @@ class GlueDataset(Dataset):
return
self
.
features
[
i
]
def
get_labels
(
self
):
return
self
.
processor
.
get_labels
()
return
self
.
label_list
src/transformers/data/processors/glue.py
View file @
3a5d1ea2
...
...
@@ -332,11 +332,12 @@ class Sst2Processor(DataProcessor):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training, dev and test sets."""
examples
=
[]
text_index
=
1
if
set_type
==
"test"
else
0
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
line
[
0
]
text_a
=
line
[
text_index
]
label
=
None
if
set_type
==
"test"
else
line
[
1
]
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
None
,
label
=
label
))
return
examples
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment