Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
431a9ca3
Commit
431a9ca3
authored
Feb 01, 2021
by
stephenwu
Browse files
added AX-g preprocessor
parent
80993c41
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
8 deletions
+53
-8
official/nlp/data/classifier_data_lib.py
official/nlp/data/classifier_data_lib.py
+48
-6
official/nlp/data/create_finetuning_data.py
official/nlp/data/create_finetuning_data.py
+5
-2
No files found.
official/nlp/data/classifier_data_lib.py
View file @
431a9ca3
...
@@ -18,6 +18,7 @@ import collections
...
@@ -18,6 +18,7 @@ import collections
import
csv
import
csv
import
importlib
import
importlib
import
os
import
os
import
json
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -1275,6 +1276,46 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
...
@@ -1275,6 +1276,46 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
return
feature
return
feature
class AXgProcessor(DataProcessor):
  """Processor for the AX-g dataset (GLUE diagnostics dataset).

  Examples are read from JSON Lines files ("dev.jsonl"/"test.jsonl") in which
  each record carries "idx", "hypothesis", "premise" and "label" fields.
  """

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "dev.jsonl")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")

  def get_labels(self):
    """See base class."""
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "AXg"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training/dev/test sets.

    Args:
      lines: Decoded JSON records, one per example.
      set_type: Split name (e.g. "dev" or "test"), used as the guid prefix.

    Returns:
      A list of `InputExample`s.
    """
    examples = []
    for line in lines:
      # The record index is stringified so process_text_fn sees text input.
      guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
      examples.append(
          InputExample(
              guid=guid,
              text_a=self.process_text_fn(line["hypothesis"]),
              text_b=self.process_text_fn(line["premise"]),
              label=self.process_text_fn(line["label"])))
    return examples

  def _read_jsonl(self, input_path):
    """Reads a JSON Lines file and returns the decoded records as a list."""
    # tf.io.gfile supports both local paths and remote filesystems (e.g. GCS).
    with tf.io.gfile.GFile(input_path, "r") as f:
      return [json.loads(json_str) for json_str in f]
def
file_based_convert_examples_to_features
(
examples
,
def
file_based_convert_examples_to_features
(
examples
,
label_list
,
label_list
,
...
@@ -1374,13 +1415,14 @@ def generate_tf_record_from_data_file(processor,
...
@@ -1374,13 +1415,14 @@ def generate_tf_record_from_data_file(processor,
label_type
=
getattr
(
processor
,
"label_type"
,
None
)
label_type
=
getattr
(
processor
,
"label_type"
,
None
)
is_regression
=
getattr
(
processor
,
"is_regression"
,
False
)
is_regression
=
getattr
(
processor
,
"is_regression"
,
False
)
has_sample_weights
=
getattr
(
processor
,
"weight_key"
,
False
)
has_sample_weights
=
getattr
(
processor
,
"weight_key"
,
False
)
assert
train_data_output_path
train_input_data_examples
=
processor
.
get_train_examples
(
data_dir
)
num_training_data
=
0
file_based_convert_examples_to_features
(
train_input_data_examples
,
label_list
,
if
train_data_output_path
:
max_seq_length
,
tokenizer
,
train_input_data_examples
=
processor
.
get_train_examples
(
data_dir
)
train_data_output_path
,
label_type
)
file_based_convert_examples_to_features
(
train_input_data_examples
,
label_list
,
num_training_data
=
len
(
train_input_data_examples
)
max_seq_length
,
tokenizer
,
train_data_output_path
,
label_type
)
num_training_data
=
len
(
train_input_data_examples
)
if
eval_data_output_path
:
if
eval_data_output_path
:
eval_input_data_examples
=
processor
.
get_dev_examples
(
data_dir
)
eval_input_data_examples
=
processor
.
get_dev_examples
(
data_dir
)
...
...
official/nlp/data/create_finetuning_data.py
View file @
431a9ca3
...
@@ -49,7 +49,7 @@ flags.DEFINE_string(
...
@@ -49,7 +49,7 @@ flags.DEFINE_string(
flags
.
DEFINE_enum
(
flags
.
DEFINE_enum
(
"classification_task_name"
,
"MNLI"
,
[
"classification_task_name"
,
"MNLI"
,
[
"AX"
,
"COLA"
,
"IMDB"
,
"MNLI"
,
"MRPC"
,
"PAWS-X"
,
"QNLI"
,
"QQP"
,
"RTE"
,
"AX"
,
"COLA"
,
"IMDB"
,
"MNLI"
,
"MRPC"
,
"PAWS-X"
,
"QNLI"
,
"QQP"
,
"RTE"
,
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
,
"AX-g"
],
"The name of the task to train BERT classifier. The "
],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"of input tsv files; 2. the dev set for XTREME is english "
...
@@ -238,7 +238,10 @@ def generate_classifier_dataset():
...
@@ -238,7 +238,10 @@ def generate_classifier_dataset():
functools
.
partial
(
functools
.
partial
(
classifier_data_lib
.
XtremePawsxProcessor
,
classifier_data_lib
.
XtremePawsxProcessor
,
translated_data_dir
=
FLAGS
.
translated_input_data_dir
,
translated_data_dir
=
FLAGS
.
translated_input_data_dir
,
only_use_en_dev
=
FLAGS
.
only_use_en_dev
)
only_use_en_dev
=
FLAGS
.
only_use_en_dev
),
"ax-g"
:
classifier_data_lib
.
AXgProcessor
}
}
task_name
=
FLAGS
.
classification_task_name
.
lower
()
task_name
=
FLAGS
.
classification_task_name
.
lower
()
if
task_name
not
in
processors
:
if
task_name
not
in
processors
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment