Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
52515dc3
Commit
52515dc3
authored
Jul 20, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 322197751
parent
57253ebc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
66 additions
and
43 deletions
+66
-43
official/nlp/data/classifier_data_lib.py
official/nlp/data/classifier_data_lib.py
+52
-35
official/nlp/data/create_finetuning_data.py
official/nlp/data/create_finetuning_data.py
+14
-8
No files found.
official/nlp/data/classifier_data_lib.py
View file @
52515dc3
...
@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
...
@@ -152,10 +152,10 @@ class ColaProcessor(DataProcessor):
return
"COLA"
return
"COLA"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
# Only the test set has a header
# Only the test set has a header
.
if
set_type
==
"test"
and
i
==
0
:
if
set_type
==
"test"
and
i
==
0
:
continue
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
...
@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
...
@@ -173,6 +173,14 @@ class ColaProcessor(DataProcessor):
class
MnliProcessor
(
DataProcessor
):
class
MnliProcessor
(
DataProcessor
):
"""Processor for the MultiNLI data set (GLUE version)."""
"""Processor for the MultiNLI data set (GLUE version)."""
def
__init__
(
self
,
mnli_type
=
"matched"
,
process_text_fn
=
tokenization
.
convert_to_unicode
):
super
(
MnliProcessor
,
self
).
__init__
(
process_text_fn
)
if
mnli_type
not
in
(
"matched"
,
"mismatched"
):
raise
ValueError
(
"Invalid `mnli_type`: %s"
%
mnli_type
)
self
.
mnli_type
=
mnli_type
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples
(
...
@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
...
@@ -180,14 +188,23 @@ class MnliProcessor(DataProcessor):
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
if
self
.
mnli_type
==
"matched"
:
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_matched.tsv"
)),
return
self
.
_create_examples
(
"dev_matched"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_matched.tsv"
)),
"dev_matched"
)
else
:
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_mismatched.tsv"
)),
"dev_mismatched"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
if
self
.
mnli_type
==
"matched"
:
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test_matched.tsv"
)),
"test"
)
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test_matched.tsv"
)),
"test"
)
else
:
return
self
.
_create_examples
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test_mismatched.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor):
...
@@ -199,9 +216,9 @@ class MnliProcessor(DataProcessor):
return
"MNLI"
return
"MNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"%s-%s"
%
(
set_type
,
self
.
process_text_fn
(
line
[
0
]))
guid
=
"%s-%s"
%
(
set_type
,
self
.
process_text_fn
(
line
[
0
]))
...
@@ -244,9 +261,9 @@ class MrpcProcessor(DataProcessor):
...
@@ -244,9 +261,9 @@ class MrpcProcessor(DataProcessor):
return
"MRPC"
return
"MRPC"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
...
@@ -290,7 +307,7 @@ class PawsxProcessor(DataProcessor):
...
@@ -290,7 +307,7 @@ class PawsxProcessor(DataProcessor):
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
language
,
train_tsv
))[
1
:])
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
language
,
train_tsv
))[
1
:])
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
1
])
text_a
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
2
])
...
@@ -307,7 +324,7 @@ class PawsxProcessor(DataProcessor):
...
@@ -307,7 +324,7 @@ class PawsxProcessor(DataProcessor):
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
lang
,
"dev_2k.tsv"
))[
1
:])
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
lang
,
"dev_2k.tsv"
))[
1
:])
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
1
])
text_a
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
2
])
...
@@ -321,7 +338,7 @@ class PawsxProcessor(DataProcessor):
...
@@ -321,7 +338,7 @@ class PawsxProcessor(DataProcessor):
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
for
lang
in
self
.
supported_languages
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
lang
,
"test_2k.tsv"
))[
1
:]
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
lang
,
"test_2k.tsv"
))[
1
:]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"test-%d"
%
i
guid
=
"test-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
1
])
text_a
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
2
])
...
@@ -368,9 +385,9 @@ class QnliProcessor(DataProcessor):
...
@@ -368,9 +385,9 @@ class QnliProcessor(DataProcessor):
return
"QNLI"
return
"QNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"%s-%s"
%
(
set_type
,
1
)
guid
=
"%s-%s"
%
(
set_type
,
1
)
...
@@ -415,9 +432,9 @@ class QqpProcessor(DataProcessor):
...
@@ -415,9 +432,9 @@ class QqpProcessor(DataProcessor):
return
"QQP"
return
"QQP"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
...
@@ -462,7 +479,7 @@ class RteProcessor(DataProcessor):
...
@@ -462,7 +479,7 @@ class RteProcessor(DataProcessor):
return
"RTE"
return
"RTE"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
...
@@ -507,9 +524,9 @@ class SstProcessor(DataProcessor):
...
@@ -507,9 +524,9 @@ class SstProcessor(DataProcessor):
return
"SST-2"
return
"SST-2"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
...
@@ -558,7 +575,7 @@ class StsBProcessor(DataProcessor):
...
@@ -558,7 +575,7 @@ class StsBProcessor(DataProcessor):
return
"STS-B"
return
"STS-B"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
...
@@ -671,7 +688,7 @@ class TfdsProcessor(DataProcessor):
...
@@ -671,7 +688,7 @@ class TfdsProcessor(DataProcessor):
return
"TFDS_"
+
self
.
dataset_name
return
"TFDS_"
+
self
.
dataset_name
def
_create_examples
(
self
,
split_name
,
set_type
):
def
_create_examples
(
self
,
split_name
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
if
split_name
not
in
self
.
dataset
:
if
split_name
not
in
self
.
dataset
:
raise
ValueError
(
"Split {} not available."
.
format
(
split_name
))
raise
ValueError
(
"Split {} not available."
.
format
(
split_name
))
dataset
=
self
.
dataset
[
split_name
].
as_numpy_iterator
()
dataset
=
self
.
dataset
[
split_name
].
as_numpy_iterator
()
...
@@ -731,7 +748,7 @@ class WnliProcessor(DataProcessor):
...
@@ -731,7 +748,7 @@ class WnliProcessor(DataProcessor):
return
"WNLI"
return
"WNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
...
@@ -777,7 +794,7 @@ class XnliProcessor(DataProcessor):
...
@@ -777,7 +794,7 @@ class XnliProcessor(DataProcessor):
"multinli.train.%s.tsv"
%
language
))[
1
:])
"multinli.train.%s.tsv"
%
language
))[
1
:])
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -792,7 +809,7 @@ class XnliProcessor(DataProcessor):
...
@@ -792,7 +809,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.dev.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.dev.tsv"
))
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"dev-%d"
%
i
guid
=
"dev-%d"
%
i
...
@@ -807,7 +824,7 @@ class XnliProcessor(DataProcessor):
...
@@ -807,7 +824,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.test.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.test.tsv"
))
examples_by_lang
=
{
k
:
[]
for
k
in
XnliProcessor
.
supported_languages
}
examples_by_lang
=
{
k
:
[]
for
k
in
XnliProcessor
.
supported_languages
}
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
continue
continue
guid
=
"test-%d"
%
i
guid
=
"test-%d"
%
i
...
@@ -837,7 +854,7 @@ class XtremePawsxProcessor(DataProcessor):
...
@@ -837,7 +854,7 @@ class XtremePawsxProcessor(DataProcessor):
"""See base class."""
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -851,7 +868,7 @@ class XtremePawsxProcessor(DataProcessor):
...
@@ -851,7 +868,7 @@ class XtremePawsxProcessor(DataProcessor):
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -865,7 +882,7 @@ class XtremePawsxProcessor(DataProcessor):
...
@@ -865,7 +882,7 @@ class XtremePawsxProcessor(DataProcessor):
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
for
lang
in
self
.
supported_languages
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"test-%d"
%
i
guid
=
"test-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -896,7 +913,7 @@ class XtremeXnliProcessor(DataProcessor):
...
@@ -896,7 +913,7 @@ class XtremeXnliProcessor(DataProcessor):
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -909,7 +926,7 @@ class XtremeXnliProcessor(DataProcessor):
...
@@ -909,7 +926,7 @@ class XtremeXnliProcessor(DataProcessor):
"""See base class."""
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
examples
=
[]
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -923,7 +940,7 @@ class XtremeXnliProcessor(DataProcessor):
...
@@ -923,7 +940,7 @@ class XtremeXnliProcessor(DataProcessor):
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
for
lang
in
self
.
supported_languages
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"test-
{
i
}
"
guid
=
f
"test-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
@@ -1052,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
...
@@ -1052,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
writer
=
tf
.
io
.
TFRecordWriter
(
output_file
)
writer
=
tf
.
io
.
TFRecordWriter
(
output_file
)
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
for
ex_index
,
example
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
if
ex_index
%
10000
==
0
:
logging
.
info
(
"Writing example %d of %d"
,
ex_index
,
len
(
examples
))
logging
.
info
(
"Writing example %d of %d"
,
ex_index
,
len
(
examples
))
...
...
official/nlp/data/create_finetuning_data.py
View file @
52515dc3
...
@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
...
@@ -59,27 +59,32 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
"only and for XNLI is all languages combined. Same for "
"only and for XNLI is all languages combined. Same for "
"PAWS-X."
)
"PAWS-X."
)
# XNLI task specific flag.
# MNLI task-specific flag.
flags
.
DEFINE_enum
(
"mnli_type"
,
"matched"
,
[
"matched"
,
"mismatched"
],
"The type of MNLI dataset."
)
# XNLI task-specific flag.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
"xnli_language"
,
"en"
,
"xnli_language"
,
"en"
,
"Language of training data for XN
I
L task. If the value is 'all', the data "
"Language of training data for XNL
I
task. If the value is 'all', the data "
"of all languages will be used for training."
)
"of all languages will be used for training."
)
# PAWS-X task
specific flag.
# PAWS-X task
-
specific flag.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
"pawsx_language"
,
"en"
,
"pawsx_language"
,
"en"
,
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of traini
n
g data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training."
)
"of all languages will be used for training."
)
# Retrieva task
specific flags
# Retrieva
l
task
-
specific flags
.
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
"The name of sentence retrieval task for scoring"
)
"The name of sentence retrieval task for scoring"
)
# Tagging task
specific flags
# Tagging task
-
specific flags
.
flags
.
DEFINE_enum
(
"tagging_task_name"
,
"panx"
,
[
"panx"
,
"udpos"
],
flags
.
DEFINE_enum
(
"tagging_task_name"
,
"panx"
,
[
"panx"
,
"udpos"
],
"The name of BERT tagging (token classification) task."
)
"The name of BERT tagging (token classification) task."
)
# BERT Squad task
specific flags.
# BERT Squad task
-
specific flags.
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
"squad_data_file"
,
None
,
"squad_data_file"
,
None
,
"The input data file in for generating training data for BERT squad task."
)
"The input data file in for generating training data for BERT squad task."
)
...
@@ -179,7 +184,8 @@ def generate_classifier_dataset():
...
@@ -179,7 +184,8 @@ def generate_classifier_dataset():
"cola"
:
"cola"
:
classifier_data_lib
.
ColaProcessor
,
classifier_data_lib
.
ColaProcessor
,
"mnli"
:
"mnli"
:
classifier_data_lib
.
MnliProcessor
,
functools
.
partial
(
classifier_data_lib
.
MnliProcessor
,
mnli_type
=
FLAGS
.
mnli_type
),
"mrpc"
:
"mrpc"
:
classifier_data_lib
.
MrpcProcessor
,
classifier_data_lib
.
MrpcProcessor
,
"qnli"
:
"qnli"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment