Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
043c8781
Commit
043c8781
authored
Mar 14, 2019
by
Ananya Harsh Jha
Browse files
added code for all glue task processors
parent
9b03d67b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
191 additions
and
0 deletions
+191
-0
examples/run_classifier.py
examples/run_classifier.py
+191
-0
No files found.
examples/run_classifier.py
View file @
043c8781
...
...
@@ -167,6 +167,16 @@ class MnliProcessor(DataProcessor):
return
examples
class
MnliMismatchedProcessor
(
MnliProcessor
):
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_mismatched.tsv"
)),
"dev_matched"
)
class
ColaProcessor
(
DataProcessor
):
"""Processor for the CoLA data set (GLUE version)."""
...
...
@@ -227,6 +237,170 @@ class Sst2Processor(DataProcessor):
return
examples
class
StsbProcessor
(
DataProcessor
):
"""Processor for the STS-B data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_labels
(
self
):
"""See base class."""
return
[
None
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
text_a
=
line
[
7
]
text_b
=
line
[
8
]
label
=
line
[
-
1
]
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
QqpProcessor
(
DataProcessor
):
"""Processor for the STS-B data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
try
:
text_a
=
line
[
3
]
text_b
=
line
[
4
]
label
=
line
[
5
]
except
IndexError
:
continue
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
QnliProcessor
(
DataProcessor
):
"""Processor for the STS-B data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev_matched"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"entailment"
,
"not_entailment"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
text_a
=
line
[
1
]
text_b
=
line
[
2
]
label
=
line
[
-
1
]
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
RteProcessor
(
DataProcessor
):
"""Processor for the RTE data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"entailment"
,
"not_entailment"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
text_a
=
line
[
1
]
text_b
=
line
[
2
]
label
=
line
[
-
1
]
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
WnliProcessor
(
DataProcessor
):
"""Processor for the WNLI data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
]
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
text_a
=
line
[
1
]
text_b
=
line
[
2
]
label
=
line
[
-
1
]
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
):
"""Loads a data file into a list of `InputBatch`s."""
...
...
@@ -433,6 +607,23 @@ def main():
"mnli"
:
MnliProcessor
,
"mrpc"
:
MrpcProcessor
,
"sst-2"
:
Sst2Processor
,
"sts-b"
:
StsbProcessor
,
"qqp"
:
QqpProcessor
,
"qnli"
:
QnliProcessor
,
"rte"
:
RteProcessor
,
"wnli"
:
WnliProcessor
,
}
output_modes
=
{
"cola"
:
"classification"
,
"mnli"
:
"classification"
,
"mrpc"
:
"classification"
,
"sst-2"
:
"classification"
,
"sts-b"
:
"regression"
,
"qqp"
:
"classification"
,
"qnli"
:
"classification"
,
"rte"
:
"classification"
,
"wnli"
:
"classification"
,
}
num_labels_task
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment