Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5ad16f95
Commit
5ad16f95
authored
Jun 18, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 380296477
parent
19a49ae3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
159 additions
and
182 deletions
+159
-182
official/nlp/data/classifier_data_lib.py
official/nlp/data/classifier_data_lib.py
+151
-180
official/nlp/data/create_finetuning_data.py
official/nlp/data/create_finetuning_data.py
+8
-2
No files found.
official/nlp/data/classifier_data_lib.py
View file @
5ad16f95
...
@@ -135,18 +135,22 @@ class AxProcessor(DataProcessor):
...
@@ -135,18 +135,22 @@ class AxProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
train_mnli_dataset
=
tfds
.
load
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
"glue/mnli"
,
split
=
"train"
,
try_gcs
=
True
).
as_numpy_iterator
()
return
self
.
_create_examples_tfds
(
train_mnli_dataset
,
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
val_mnli_dataset
=
tfds
.
load
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
"glue/mnli"
,
split
=
"validation_matched"
,
try_gcs
=
True
).
as_numpy_iterator
()
return
self
.
_create_examples_tfds
(
val_mnli_dataset
,
"validation"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
test_ax_dataset
=
tfds
.
load
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
"glue/ax"
,
split
=
"test"
,
try_gcs
=
True
).
as_numpy_iterator
()
return
self
.
_create_examples_tfds
(
test_ax_dataset
,
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -157,24 +161,20 @@ class AxProcessor(DataProcessor):
...
@@ -157,24 +161,20 @@ class AxProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"AX"
return
"AX"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
dataset
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
text_a_index
=
1
if
set_type
==
"test"
else
8
text_b_index
=
2
if
set_type
==
"test"
else
9
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
# Skip header.
guid
=
"%s-%s"
%
(
set_type
,
i
)
if
i
==
0
:
label
=
"contradiction"
continue
text_a
=
self
.
process_text_fn
(
example
[
"hypothesis"
])
guid
=
"%s-%s"
%
(
set_type
,
self
.
process_text_fn
(
line
[
0
]))
text_b
=
self
.
process_text_fn
(
example
[
"premise"
])
text_a
=
self
.
process_text_fn
(
line
[
text_a_index
])
if
set_type
!=
"test"
:
text_b
=
self
.
process_text_fn
(
line
[
text_b_index
])
label
=
self
.
get_labels
()[
example
[
"label"
]]
if
set_type
==
"test"
:
label
=
"contradiction"
else
:
label
=
self
.
process_text_fn
(
line
[
-
1
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -264,34 +264,28 @@ class MnliProcessor(DataProcessor):
...
@@ -264,34 +264,28 @@ class MnliProcessor(DataProcessor):
mnli_type
=
"matched"
,
mnli_type
=
"matched"
,
process_text_fn
=
tokenization
.
convert_to_unicode
):
process_text_fn
=
tokenization
.
convert_to_unicode
):
super
(
MnliProcessor
,
self
).
__init__
(
process_text_fn
)
super
(
MnliProcessor
,
self
).
__init__
(
process_text_fn
)
self
.
dataset
=
tfds
.
load
(
"glue/mnli"
,
try_gcs
=
True
)
if
mnli_type
not
in
(
"matched"
,
"mismatched"
):
if
mnli_type
not
in
(
"matched"
,
"mismatched"
):
raise
ValueError
(
"Invalid `mnli_type`: %s"
%
mnli_type
)
raise
ValueError
(
"Invalid `mnli_type`: %s"
%
mnli_type
)
self
.
mnli_type
=
mnli_type
self
.
mnli_type
=
mnli_type
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
if
self
.
mnli_type
==
"matched"
:
if
self
.
mnli_type
==
"matched"
:
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation_matched"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_matched.tsv"
)),
"dev_matched"
)
else
:
else
:
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation_mismatched"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev_mismatched.tsv"
)),
"dev_mismatched"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
if
self
.
mnli_type
==
"matched"
:
if
self
.
mnli_type
==
"matched"
:
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test_matched"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test_matched.tsv"
)),
"test"
)
else
:
else
:
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test_mismatched"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test_mismatched.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -302,21 +296,22 @@ class MnliProcessor(DataProcessor):
...
@@ -302,21 +296,22 @@ class MnliProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"MNLI"
return
"MNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/mnli"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
guid
=
"%s-%s"
%
(
set_type
,
i
)
continue
label
=
"contradiction"
guid
=
"%s-%s"
%
(
set_type
,
self
.
process_text_fn
(
line
[
0
]))
text_a
=
self
.
process_text_fn
(
example
[
"hypothesis"
])
text_a
=
self
.
process_text_fn
(
line
[
8
])
text_b
=
self
.
process_text_fn
(
example
[
"premise"
])
text_b
=
self
.
process_text_fn
(
line
[
9
])
if
set_type
!=
"test"
:
if
set_type
==
"test"
:
label
=
self
.
get_labels
()[
example
[
"label"
]]
label
=
"contradiction"
else
:
label
=
self
.
process_text_fn
(
line
[
-
1
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -325,18 +320,15 @@ class MrpcProcessor(DataProcessor):
...
@@ -325,18 +320,15 @@ class MrpcProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -347,21 +339,22 @@ class MrpcProcessor(DataProcessor):
...
@@ -347,21 +339,22 @@ class MrpcProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"MRPC"
return
"MRPC"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/mrpc"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
self
.
process_text_fn
(
line
[
3
])
label
=
"0"
text_b
=
self
.
process_text_fn
(
line
[
4
])
text_a
=
self
.
process_text_fn
(
example
[
"sentence1"
])
if
set_type
==
"test"
:
text_b
=
self
.
process_text_fn
(
example
[
"sentence2"
])
label
=
"0"
if
set_type
!=
"test"
:
else
:
label
=
str
(
example
[
"label"
])
label
=
self
.
process_text_fn
(
line
[
0
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -449,18 +442,15 @@ class QnliProcessor(DataProcessor):
...
@@ -449,18 +442,15 @@ class QnliProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev_matched"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -471,23 +461,22 @@ class QnliProcessor(DataProcessor):
...
@@ -471,23 +461,22 @@ class QnliProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"QNLI"
return
"QNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/qnli"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
guid
=
"%s-%s"
%
(
set_type
,
i
)
continue
label
=
"entailment"
guid
=
"%s-%s"
%
(
set_type
,
1
)
text_a
=
self
.
process_text_fn
(
example
[
"question"
])
if
set_type
==
"test"
:
text_b
=
self
.
process_text_fn
(
example
[
"sentence"
])
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
if
set_type
!=
"test"
:
text_b
=
tokenization
.
convert_to_unicode
(
line
[
2
])
label
=
self
.
get_labels
()[
example
[
"label"
]]
label
=
"entailment"
else
:
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
2
])
label
=
tokenization
.
convert_to_unicode
(
line
[
-
1
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -496,18 +485,15 @@ class QqpProcessor(DataProcessor):
...
@@ -496,18 +485,15 @@ class QqpProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -518,27 +504,22 @@ class QqpProcessor(DataProcessor):
...
@@ -518,27 +504,22 @@ class QqpProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"QQP"
return
"QQP"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/qqp"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
guid
=
"%s-%s"
%
(
set_type
,
i
)
continue
label
=
"0"
guid
=
"%s-%s"
%
(
set_type
,
line
[
0
])
text_a
=
self
.
process_text_fn
(
example
[
"question1"
])
if
set_type
==
"test"
:
text_b
=
self
.
process_text_fn
(
example
[
"question2"
])
text_a
=
line
[
1
]
if
set_type
!=
"test"
:
text_b
=
line
[
2
]
label
=
str
(
example
[
"label"
])
label
=
"0"
else
:
# There appear to be some garbage lines in the train dataset.
try
:
text_a
=
line
[
3
]
text_b
=
line
[
4
]
label
=
line
[
5
]
except
IndexError
:
continue
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -547,18 +528,15 @@ class RteProcessor(DataProcessor):
...
@@ -547,18 +528,15 @@ class RteProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -571,21 +549,22 @@ class RteProcessor(DataProcessor):
...
@@ -571,21 +549,22 @@ class RteProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"RTE"
return
"RTE"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/rte"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
label
=
"entailment"
text_b
=
tokenization
.
convert_to_unicode
(
line
[
2
])
text_a
=
self
.
process_text_fn
(
example
[
"sentence1"
])
if
set_type
==
"test"
:
text_b
=
self
.
process_text_fn
(
example
[
"sentence2"
])
label
=
"entailment"
if
set_type
!=
"test"
:
else
:
label
=
self
.
get_labels
()[
example
[
"label"
]]
label
=
tokenization
.
convert_to_unicode
(
line
[
3
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -594,18 +573,15 @@ class SstProcessor(DataProcessor):
...
@@ -594,18 +573,15 @@ class SstProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -616,21 +592,20 @@ class SstProcessor(DataProcessor):
...
@@ -616,21 +592,20 @@ class SstProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"SST-2"
return
"SST-2"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/sst2"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
if
set_type
==
"test"
:
label
=
"0"
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
text_a
=
self
.
process_text_fn
(
example
[
"sentence"
])
label
=
"0"
if
set_type
!=
"test"
:
else
:
label
=
str
(
example
[
"label"
])
text_a
=
tokenization
.
convert_to_unicode
(
line
[
0
])
label
=
tokenization
.
convert_to_unicode
(
line
[
1
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
None
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
None
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
@@ -645,18 +620,33 @@ class StsBProcessor(DataProcessor):
...
@@ -645,18 +620,33 @@ class StsBProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
_create_examples_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/stsb"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
for
i
,
example
in
enumerate
(
dataset
):
guid
=
"%s-%s"
%
(
set_type
,
i
)
label
=
0.0
text_a
=
self
.
process_text_fn
(
example
[
"sentence1"
])
text_b
=
self
.
process_text_fn
(
example
[
"sentence2"
])
if
set_type
!=
"test"
:
label
=
self
.
label_type
(
example
[
"label"
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -667,23 +657,6 @@ class StsBProcessor(DataProcessor):
...
@@ -667,23 +657,6 @@ class StsBProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"STS-B"
return
"STS-B"
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training/dev/test sets."""
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
tokenization
.
convert_to_unicode
(
line
[
7
])
text_b
=
tokenization
.
convert_to_unicode
(
line
[
8
])
if
set_type
==
"test"
:
label
=
0.0
else
:
label
=
self
.
label_type
(
tokenization
.
convert_to_unicode
(
line
[
9
]))
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
class
TfdsProcessor
(
DataProcessor
):
class
TfdsProcessor
(
DataProcessor
):
"""Processor for generic text classification and regression TFDS data set.
"""Processor for generic text classification and regression TFDS data set.
...
@@ -818,18 +791,15 @@ class WnliProcessor(DataProcessor):
...
@@ -818,18 +791,15 @@ class WnliProcessor(DataProcessor):
def
get_train_examples
(
self
,
data_dir
):
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"train"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train.tsv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"validation"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev.tsv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
"""See base class."""
return
self
.
_create_examples
(
return
self
.
_create_examples_tfds
(
"test"
)
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"test.tsv"
)),
"test"
)
def
get_labels
(
self
):
def
get_labels
(
self
):
"""See base class."""
"""See base class."""
...
@@ -840,21 +810,22 @@ class WnliProcessor(DataProcessor):
...
@@ -840,21 +810,22 @@ class WnliProcessor(DataProcessor):
"""See base class."""
"""See base class."""
return
"WNLI"
return
"WNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
_tfds
(
self
,
set_type
):
"""Creates examples for the training/dev/test sets."""
"""Creates examples for the training/dev/test sets."""
dataset
=
tfds
.
load
(
"glue/wnli"
,
split
=
set_type
,
try_gcs
=
True
).
as_numpy_iterator
()
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
example
in
enumerate
(
dataset
):
if
i
==
0
:
continue
guid
=
"%s-%s"
%
(
set_type
,
i
)
guid
=
"%s-%s"
%
(
set_type
,
i
)
text_a
=
tokenization
.
convert_to_unicode
(
line
[
1
])
label
=
"0"
text_b
=
tokenization
.
convert_to_unicode
(
line
[
2
])
text_a
=
self
.
process_text_fn
(
example
[
"sentence1"
])
if
set_type
==
"test"
:
text_b
=
self
.
process_text_fn
(
example
[
"sentence2"
])
label
=
"0"
if
set_type
!=
"test"
:
else
:
label
=
str
(
example
[
"label"
])
label
=
tokenization
.
convert_to_unicode
(
line
[
3
])
examples
.
append
(
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
None
))
return
examples
return
examples
...
...
official/nlp/data/create_finetuning_data.py
View file @
5ad16f95
...
@@ -173,8 +173,14 @@ flags.DEFINE_string(
...
@@ -173,8 +173,14 @@ flags.DEFINE_string(
def
generate_classifier_dataset
():
def
generate_classifier_dataset
():
"""Generates classifier dataset and returns input meta data."""
"""Generates classifier dataset and returns input meta data."""
assert
(
FLAGS
.
input_data_dir
and
FLAGS
.
classification_task_name
or
if
FLAGS
.
classification_task_name
in
[
FLAGS
.
tfds_params
)
"COLA"
,
"WNLI"
,
"SST-2"
,
"MRPC"
,
"QQP"
,
"STS-B"
,
"MNLI"
,
"QNLI"
,
"RTE"
,
"AX"
]:
assert
not
FLAGS
.
input_data_dir
or
FLAGS
.
tfds_params
else
:
assert
(
FLAGS
.
input_data_dir
and
FLAGS
.
classification_task_name
or
FLAGS
.
tfds_params
)
if
FLAGS
.
tokenization
==
"WordPiece"
:
if
FLAGS
.
tokenization
==
"WordPiece"
:
tokenizer
=
tokenization
.
FullTokenizer
(
tokenizer
=
tokenization
.
FullTokenizer
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment