ModelZoo / ResNet50_tensorflow · Commit d697041a

Authored Mar 19, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 301915584
Parent: 19d930c3
Showing 2 changed files with 133 additions and 21 deletions:

  official/nlp/data/classifier_data_lib.py     +94  -0
  official/nlp/data/create_finetuning_data.py  +39  -21
official/nlp/data/classifier_data_lib.py

@@ -24,6 +24,7 @@ import os
 from absl import logging
 import tensorflow as tf
+import tensorflow_datasets as tfds
 from official.nlp.bert import tokenization
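
The new tensorflow_datasets dependency is what the TfdsProcessor added below is built on. As a rough, self-contained sketch (not part of the diff; the dataset name and feature keys are just examples), the load-and-inspect pattern it relies on looks like this:

# Illustrative sketch only: tfds.load with with_info=True returns the split
# dict plus a DatasetInfo whose ClassLabel feature exposes num_classes --
# the two pieces of information TfdsProcessor.__init__ reads below.
import tensorflow_datasets as tfds

dataset, info = tfds.load("glue/mrpc", with_info=True)
num_classes = info.features["label"].num_classes  # 2 for MRPC
for example in dataset["validation"].as_numpy_iterator():
  print(example["sentence1"], example["sentence2"], example["label"])
  break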
@@ -386,6 +387,99 @@ class QnliProcessor(DataProcessor):
     return examples


+class TfdsProcessor(DataProcessor):
+  """Processor for generic text classification TFDS data set.
+
+  The TFDS parameters are expected to be provided in the tfds_params string,
+  in a comma-separated list of parameter assignments.
+  Examples:
+    tfds_params="dataset=scicite,text_key=string"
+    tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
+    tfds_params="dataset=glue/cola,text_key=sentence"
+    tfds_params="dataset=glue/sst2,text_key=sentence"
+    tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
+    tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
+
+  Possible parameters (please refer to the documentation of TensorFlow
+  Datasets (TFDS) for the meaning of individual parameters):
+    dataset: Required dataset name (potentially with subset and version
+      number).
+    data_dir: Optional TFDS source root directory.
+    train_split: Name of the train split (defaults to `train`).
+    dev_split: Name of the dev split (defaults to `validation`).
+    test_split: Name of the test split (defaults to `test`).
+    text_key: Key of the text_a feature (defaults to `text`).
+    text_b_key: Key of the second text feature if available.
+    label_key: Key of the label feature (defaults to `label`).
+    test_text_key: Key of the text feature to use in the test set.
+    test_text_b_key: Key of the second text feature to use in the test set.
+    test_label: String to be used as the label for all test examples.
+  """
+
+  def __init__(self, tfds_params,
+               process_text_fn=tokenization.convert_to_unicode):
+    super(TfdsProcessor, self).__init__(process_text_fn)
+    self._process_tfds_params_str(tfds_params)
+    self.dataset, info = tfds.load(
+        self.dataset_name, data_dir=self.data_dir, with_info=True)
+    self._labels = list(range(info.features[self.label_key].num_classes))
+
+  def _process_tfds_params_str(self, params_str):
+    """Extracts TFDS parameters from a comma-separated assignments string."""
+    tuples = [x.split("=") for x in params_str.split(",")]
+    d = {k.strip(): v.strip() for k, v in tuples}
+    self.dataset_name = d["dataset"]  # Required.
+    self.data_dir = d.get("data_dir", None)
+    self.train_split = d.get("train_split", "train")
+    self.dev_split = d.get("dev_split", "validation")
+    self.test_split = d.get("test_split", "test")
+    self.text_key = d.get("text_key", "text")
+    self.text_b_key = d.get("text_b_key", None)
+    self.label_key = d.get("label_key", "label")
+    self.test_text_key = d.get("test_text_key", self.text_key)
+    self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
+    self.test_label = d.get("test_label", "test_example")
+
+  def get_train_examples(self, data_dir):
+    assert data_dir is None
+    return self._create_examples(self.train_split, "train")
+
+  def get_dev_examples(self, data_dir):
+    assert data_dir is None
+    return self._create_examples(self.dev_split, "dev")
+
+  def get_test_examples(self, data_dir):
+    assert data_dir is None
+    return self._create_examples(self.test_split, "test")
+
+  def get_labels(self):
+    return self._labels
+
+  def get_processor_name(self):
+    return "TFDS_" + self.dataset_name
+
+  def _create_examples(self, split_name, set_type):
+    """Creates examples for the training and dev sets."""
+    if split_name not in self.dataset:
+      raise ValueError("Split {} not available.".format(split_name))
+    dataset = self.dataset[split_name].as_numpy_iterator()
+    examples = []
+    text_b = None
+    for i, example in enumerate(dataset):
+      guid = "%s-%s" % (set_type, i)
+      if set_type == "test":
+        text_a = self.process_text_fn(example[self.test_text_key])
+        if self.test_text_b_key:
+          text_b = self.process_text_fn(example[self.test_text_b_key])
+        label = self.test_label
+      else:
+        text_a = self.process_text_fn(example[self.text_key])
+        if self.text_b_key:
+          text_b = self.process_text_fn(example[self.text_b_key])
+        label = int(example[self.label_key])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
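
For a quick sense of what a tfds_params string turns into, here is a minimal standalone sketch (not part of the commit; parse_tfds_params is a hypothetical helper) that mirrors the parsing done in _process_tfds_params_str above:

# Hypothetical helper mirroring TfdsProcessor._process_tfds_params_str:
# split "k1=v1,k2=v2" on commas, then on "=", and strip whitespace.
def parse_tfds_params(params_str):
  tuples = [x.split("=") for x in params_str.split(",")]
  return {k.strip(): v.strip() for k, v in tuples}

params = parse_tfds_params(
    "dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2")
print(params["dataset"])                   # glue/mrpc (the one required key)
print(params.get("train_split", "train"))  # defaults mirror the class above
print(params.get("text_b_key"))            # sentence2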
official/nlp/data/create_finetuning_data.py

@@ -104,22 +104,16 @@ flags.DEFINE_enum(
     "or sentence_piece tokenizer. Canonical BERT uses word_piece tokenizer, "
     "while ALBERT uses sentence_piece tokenizer.")

+flags.DEFINE_string(
+    "tfds_params", "",
+    "Comma-separated list of TFDS parameter assignments for "
+    "generic classification data import (for more details "
+    "see the TfdsProcessor class documentation).")
+

 def generate_classifier_dataset():
   """Generates classifier dataset and returns input meta data."""
-  assert FLAGS.input_data_dir and FLAGS.classification_task_name
-
-  processors = {
-      "cola": classifier_data_lib.ColaProcessor,
-      "mnli": classifier_data_lib.MnliProcessor,
-      "mrpc": classifier_data_lib.MrpcProcessor,
-      "qnli": classifier_data_lib.QnliProcessor,
-      "sst-2": classifier_data_lib.SstProcessor,
-      "xnli": classifier_data_lib.XnliProcessor,
-  }
-
-  task_name = FLAGS.classification_task_name.lower()
-  if task_name not in processors:
-    raise ValueError("Task not found: %s" % (task_name))
+  assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
+          FLAGS.tfds_params)

   if FLAGS.tokenizer_impl == "word_piece":
     tokenizer = tokenization.FullTokenizer(
@@ -131,14 +125,38 @@ def generate_classifier_dataset():
     processor_text_fn = functools.partial(
         tokenization.preprocess_text, lower=FLAGS.do_lower_case)

-  processor = processors[task_name](processor_text_fn)
-  return classifier_data_lib.generate_tf_record_from_data_file(
-      processor,
-      FLAGS.input_data_dir,
-      tokenizer,
-      train_data_output_path=FLAGS.train_data_output_path,
-      eval_data_output_path=FLAGS.eval_data_output_path,
-      max_seq_length=FLAGS.max_seq_length)
+  if FLAGS.tfds_params:
+    processor = classifier_data_lib.TfdsProcessor(
+        tfds_params=FLAGS.tfds_params,
+        process_text_fn=processor_text_fn)
+    return classifier_data_lib.generate_tf_record_from_data_file(
+        processor,
+        None,
+        tokenizer,
+        train_data_output_path=FLAGS.train_data_output_path,
+        eval_data_output_path=FLAGS.eval_data_output_path,
+        max_seq_length=FLAGS.max_seq_length)
+  else:
+    processors = {
+        "cola": classifier_data_lib.ColaProcessor,
+        "mnli": classifier_data_lib.MnliProcessor,
+        "mrpc": classifier_data_lib.MrpcProcessor,
+        "qnli": classifier_data_lib.QnliProcessor,
+        "sst-2": classifier_data_lib.SstProcessor,
+        "xnli": classifier_data_lib.XnliProcessor,
+    }
+    task_name = FLAGS.classification_task_name.lower()
+    if task_name not in processors:
+      raise ValueError("Task not found: %s" % (task_name))
+    processor = processors[task_name](processor_text_fn)
+    return classifier_data_lib.generate_tf_record_from_data_file(
+        processor,
+        FLAGS.input_data_dir,
+        tokenizer,
+        train_data_output_path=FLAGS.train_data_output_path,
+        eval_data_output_path=FLAGS.eval_data_output_path,
+        max_seq_length=FLAGS.max_seq_length)


 def generate_squad_dataset():
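
Putting the two changes together, the new TFDS path in generate_classifier_dataset can also be exercised directly. Below is a minimal sketch (not part of the commit), assuming a BERT WordPiece vocab file and placeholder output paths:

# Illustrative sketch only -- mirrors the new FLAGS.tfds_params branch above.
# The vocab file and output paths are placeholders, not values from the diff.
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib

tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)

processor = classifier_data_lib.TfdsProcessor(
    tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2",
    process_text_fn=tokenization.convert_to_unicode)

input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(
    processor,
    None,  # data_dir is unused when the processor reads from TFDS.
    tokenizer,
    train_data_output_path="/tmp/mrpc_train.tf_record",
    eval_data_output_path="/tmp/mrpc_eval.tf_record",
    max_seq_length=128)

The None data_dir matches the new call in the diff: TfdsProcessor loads its data through tfds.load rather than from FLAGS.input_data_dir.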