ModelZoo / ResNet50_tensorflow / Commits / 571369aa

Commit 571369aa
Authored Jul 08, 2020 by A. Unique TensorFlower
Parent: 7bb5ab6d

Add STS-B glue data preprocessor and regression support, and alpha-order the preprocessor classes.

PiperOrigin-RevId: 320318976

Showing 2 changed files with 342 additions and 285 deletions:
  official/nlp/data/classifier_data_lib.py     +338  -284
  official/nlp/data/create_finetuning_data.py    +4    -1
official/nlp/data/classifier_data_lib.py

@@ -31,7 +31,7 @@ from official.nlp.bert import tokenization
 class InputExample(object):
-  """A single training/test example for simple sequence classification."""
+  """A single training/test example for simple seq regression/classification."""
 
   def __init__(self,
                guid,
@@ -48,8 +48,9 @@ class InputExample(object):
       sequence tasks, only this sequence must be specified.
       text_b: (Optional) string. The untokenized text of the second sequence.
         Only must be specified for sequence pair tasks.
-      label: (Optional) string. The label of the example. This should be
-        specified for train and dev examples, but not for test examples.
+      label: (Optional) string for classification, float for regression. The
+        label of the example. This should be specified for train and dev
+        examples, but not for test examples.
       weight: (Optional) float. The weight of the example to be used during
         training.
       int_iden: (Optional) int. The int identification number of example in the
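For a regression task such as STS-B, the label on an InputExample is a float similarity score rather than a class string. A minimal sketch of the two cases in Python (the example sentences and keyword-style call are illustrative; only the field names come from the docstring above):

    # Classification: the label is a string class name.
    mnli_example = InputExample(
        guid="train-0",
        text_a="The cat sat on the mat.",
        text_b="A cat is on a mat.",
        label="entailment")

    # Regression (STS-B): the label is a float similarity score in [0, 5].
    stsb_example = InputExample(
        guid="train-0",
        text_a="A man is playing a guitar.",
        text_b="A person plays a guitar.",
        label=3.8)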
@@ -84,10 +85,12 @@ class InputFeatures(object):
 class DataProcessor(object):
-  """Base class for data converters for sequence classification data sets."""
+  """Base class for converters for seq regression/classification datasets."""
 
   def __init__(self, process_text_fn=tokenization.convert_to_unicode):
     self.process_text_fn = process_text_fn
+    self.is_regression = False
+    self.label_type = None
 
   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
@@ -121,76 +124,70 @@ class DataProcessor(object):
     return lines
 
 
-class XnliProcessor(DataProcessor):
-  """Processor for the XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
-
-  def __init__(self,
-               language="en",
-               process_text_fn=tokenization.convert_to_unicode):
-    super(XnliProcessor, self).__init__(process_text_fn)
-    if language == "all":
-      self.languages = XnliProcessor.supported_languages
-    elif language not in XnliProcessor.supported_languages:
-      raise ValueError("language %s is not supported for XNLI task." % language)
-    else:
-      self.languages = [language]
+class ColaProcessor(DataProcessor):
+  """Processor for the CoLA data set (GLUE version)."""
 
   def get_train_examples(self, data_dir):
     """See base class."""
-    lines = []
-    for language in self.languages:
-      # Skips the header.
-      lines.extend(
-          self._read_tsv(
-              os.path.join(data_dir, "multinli",
-                           "multinli.train.%s.tsv" % language))[1:])
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      if label == self.process_text_fn("contradictory"):
-        label = self.process_text_fn("contradiction")
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
 
   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "COLA"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      # Only the test set has a header
+      if set_type == "test" and i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      if set_type == "test":
+        text_a = self.process_text_fn(line[1])
+        label = "0"
+      else:
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+    return examples
+
+
+class MnliProcessor(DataProcessor):
+  """Processor for the MultiNLI data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+        "dev_matched")
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
-    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "test-%d" % i
-      language = self.process_text_fn(line[0])
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples_by_lang[language].append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -199,65 +196,69 @@ class XnliProcessor(DataProcessor):
   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "XNLI"
-
-
-class XtremeXnliProcessor(DataProcessor):
-  """Processor for the XTREME XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return "MNLI"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
+      if set_type == "test":
+        label = "contradiction"
+      else:
+        label = self.process_text_fn(line[-1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
+class MrpcProcessor(DataProcessor):
+  """Processor for the MRPC data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
 
   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
 
   def get_test_examples(self, data_dir):
     """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = f"test-{i}"
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "contradiction"
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
-    return ["contradiction", "entailment", "neutral"]
+    return ["0", "1"]
 
   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "XTREME-XNLI"
+    return "MRPC"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training and dev sets."""
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
+      if set_type == "test":
+        label = "0"
+      else:
+        label = self.process_text_fn(line[0])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
 
 
 class PawsxProcessor(DataProcessor):
@@ -339,154 +340,8 @@ class PawsxProcessor(DataProcessor):
     return "XTREME-PAWS-X"
 
 
-class XtremePawsxProcessor(DataProcessor):
-  """Processor for the XTREME PAWS-X data set."""
-  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = "test-%d" % i
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "0"
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "XTREME-PAWS-X"
-
-
-class MnliProcessor(DataProcessor):
-  """Processor for the MultiNLI data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["contradiction", "entailment", "neutral"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "MNLI"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
-      text_a = self.process_text_fn(line[8])
-      text_b = self.process_text_fn(line[9])
-      if set_type == "test":
-        label = "contradiction"
-      else:
-        label = self.process_text_fn(line[-1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-
-class MrpcProcessor(DataProcessor):
-  """Processor for the MRPC data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "MRPC"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "%s-%s" % (set_type, i)
-      text_a = self.process_text_fn(line[3])
-      text_b = self.process_text_fn(line[4])
-      if set_type == "test":
-        label = "0"
-      else:
-        label = self.process_text_fn(line[0])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-
-class QqpProcessor(DataProcessor):
-  """Processor for the QQP data set (GLUE version)."""
+class QnliProcessor(DataProcessor):
+  """Processor for the QNLI data set (GLUE version)."""
 
   def get_train_examples(self, data_dir):
     """See base class."""
@@ -496,7 +351,7 @@ class QqpProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
 
   def get_test_examples(self, data_dir):
     """See base class."""
@@ -505,12 +360,12 @@ class QqpProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["0", "1"]
+    return ["entailment", "not_entailment"]
 
   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "QQP"
+    return "QNLI"
 
   def _create_examples(self, lines, set_type):
     """Creates examples for the training and dev sets."""
@@ -518,20 +373,22 @@ class QqpProcessor(DataProcessor):
     for (i, line) in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, line[0])
-      try:
-        text_a = line[3]
-        text_b = line[4]
-        label = line[5]
-      except IndexError:
-        continue
+      guid = "%s-%s" % (set_type, 1)
+      if set_type == "test":
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = "entailment"
+      else:
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = tokenization.convert_to_unicode(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 
 
-class ColaProcessor(DataProcessor):
-  """Processor for the CoLA data set (GLUE version)."""
+class QqpProcessor(DataProcessor):
+  """Processor for the QQP data set (GLUE version)."""
 
   def get_train_examples(self, data_dir):
     """See base class."""
@@ -555,24 +412,23 @@ class ColaProcessor(DataProcessor):
   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "COLA"
+    return "QQP"
 
   def _create_examples(self, lines, set_type):
     """Creates examples for the training and dev sets."""
     examples = []
     for (i, line) in enumerate(lines):
-      # Only the test set has a header
-      if set_type == "test" and i == 0:
+      if i == 0:
         continue
-      guid = "%s-%s" % (set_type, i)
-      if set_type == "test":
-        text_a = self.process_text_fn(line[1])
-        label = "0"
-      else:
-        text_a = self.process_text_fn(line[3])
-        label = self.process_text_fn(line[1])
+      guid = "%s-%s" % (set_type, line[0])
+      try:
+        text_a = line[3]
+        text_b = line[4]
+        label = line[5]
+      except IndexError:
+        continue
       examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -668,8 +524,14 @@ class SstProcessor(DataProcessor):
     return examples
 
 
-class QnliProcessor(DataProcessor):
-  """Processor for the QNLI data set (GLUE version)."""
+class StsBProcessor(DataProcessor):
+  """Processor for the STS-B data set (GLUE version)."""
+
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
+    self.is_regression = True
+    self.label_type = float
+    self._labels = None
 
   def get_train_examples(self, data_dir):
     """See base class."""
@@ -679,7 +541,7 @@ class QnliProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
 
   def get_test_examples(self, data_dir):
     """See base class."""
@@ -688,28 +550,26 @@ class QnliProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["entailment", "not_entailment"]
+    return self._labels
 
   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "QNLI"
+    return "STS-B"
 
   def _create_examples(self, lines, set_type):
     """Creates examples for the training and dev sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, 1)
+      guid = "%s-%s" % (set_type, i)
+      text_a = tokenization.convert_to_unicode(line[7])
+      text_b = tokenization.convert_to_unicode(line[8])
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = "entailment"
+        label = 0.0
       else:
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.label_type(tokenization.convert_to_unicode(line[9]))
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
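In the STS-B TSV layout the two sentences sit in columns 7 and 8 and the gold similarity score in column 9, which _create_examples converts with self.label_type (float). A one-row illustration of that conversion (the row values are made up in the STS-B style, not data read from the real files):

    row = ["0", "main-captions", "MSRvid", "2012test", "0001", "none", "none",
           "A plane is taking off.", "An air plane is taking off.", "5.000"]

    label_type = float
    text_a, text_b = row[7], row[8]
    label = label_type(row[9])   # 5.0 as a Python float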
@@ -888,6 +748,200 @@ class WnliProcessor(DataProcessor):
     return examples
 
 
+class XnliProcessor(DataProcessor):
+  """Processor for the XNLI data set."""
+  supported_languages = [
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
+  ]
+
+  def __init__(self,
+               language="en",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
+    if language == "all":
+      self.languages = XnliProcessor.supported_languages
+    elif language not in XnliProcessor.supported_languages:
+      raise ValueError("language %s is not supported for XNLI task." % language)
+    else:
+      self.languages = [language]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = []
+    for language in self.languages:
+      # Skips the header.
+      lines.extend(
+          self._read_tsv(
+              os.path.join(data_dir, "multinli",
+                           "multinli.train.%s.tsv" % language))[1:])
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
+    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
+    for (i, line) in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "test-%d" % i
+      language = self.process_text_fn(line[0])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
+      examples_by_lang[language].append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XNLI"
+
+
+class XtremePawsxProcessor(DataProcessor):
+  """Processor for the XTREME PAWS-X data set."""
+  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for (i, line) in enumerate(lines):
+        guid = "test-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = "0"
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-PAWS-X"
+
+
+class XtremeXnliProcessor(DataProcessor):
+  """Processor for the XTREME XNLI data set."""
+  supported_languages = [
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
+  ]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for (i, line) in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for (i, line) in enumerate(lines):
+        guid = f"test-{i}"
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = "contradiction"
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-XNLI"
+
+
 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
official/nlp/data/create_finetuning_data.py

@@ -51,7 +51,8 @@ flags.DEFINE_string(
 flags.DEFINE_enum(
     "classification_task_name", "MNLI",
-    ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", "SST-2", "WNLI",
-     "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"],
+    ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", "SST-2", "STS-B",
+     "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"],
     "The name of the task to train BERT classifier. The "
     "difference between XTREME-XNLI and XNLI is: 1. the format "
     "of input tsv files; 2. the dev set for XTREME is english "

@@ -187,6 +188,8 @@ def generate_classifier_dataset():
         "rte": classifier_data_lib.RteProcessor,
         "sst-2":
             classifier_data_lib.SstProcessor,
+        "sts-b":
+            classifier_data_lib.StsBProcessor,
         "xnli":
             functools.partial(classifier_data_lib.XnliProcessor,
                               language=FLAGS.xnli_language),
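With the flag choice and the processor mapping above, selecting STS-B routes the run through the new regression-aware processor. A minimal sketch of that lookup in Python (only a subset of the mapping is shown, and the XNLI language is hard-coded where the script uses FLAGS.xnli_language):

    import functools
    from official.nlp.data import classifier_data_lib

    processors = {
        "sst-2": classifier_data_lib.SstProcessor,
        "sts-b": classifier_data_lib.StsBProcessor,
        "xnli": functools.partial(classifier_data_lib.XnliProcessor,
                                  language="en"),
    }

    task_name = "STS-B"
    processor = processors[task_name.lower()]()
    print(processor.get_processor_name())   # "STS-B"
    print(processor.is_regression)          # True, so labels are floats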