Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
431a9ca3
Commit
431a9ca3
authored
Feb 01, 2021
by
stephenwu
Browse files
added AX-g preprocessor
parent
80993c41
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
8 deletions
+53
-8
official/nlp/data/classifier_data_lib.py
official/nlp/data/classifier_data_lib.py
+48
-6
official/nlp/data/create_finetuning_data.py
official/nlp/data/create_finetuning_data.py
+5
-2
No files found.
official/nlp/data/classifier_data_lib.py
View file @
431a9ca3
...
...
@@ -18,6 +18,7 @@ import collections
import
csv
import
importlib
import
os
import
json
from
absl
import
logging
import
tensorflow
as
tf
...
...
@@ -1275,6 +1276,46 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
return
feature
class
AXgProcessor
(
DataProcessor
):
"""Processor for the AX dataset (GLUE diagnostics dataset)."""
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_jsonl
(
os
.
path
.
join
(
data_dir
,
"dev.jsonl"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
return
self
.
_create_examples
(
self
.
_read_jsonl
(
os
.
path
.
join
(
data_dir
,
"test.jsonl"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"entailment"
,
"not_entailment"
]
@
staticmethod
def
get_processor_name
():
"""See base class."""
return
"AXg"
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training/dev/test sets."""
examples
=
[]
for
line
in
lines
:
guid
=
"%s-%s"
%
(
set_type
,
self
.
process_text_fn
(
str
(
line
[
'idx'
])))
text_a
=
self
.
process_text_fn
(
line
[
"hypothesis"
])
text_b
=
self
.
process_text_fn
(
line
[
"premise"
])
label
=
self
.
process_text_fn
(
line
[
"label"
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
_read_jsonl
(
self
,
input_path
):
with
tf
.
io
.
gfile
.
GFile
(
input_path
,
"r"
)
as
f
:
lines
=
[]
for
json_str
in
f
:
lines
.
append
(
json
.
loads
(
json_str
))
return
lines
def
file_based_convert_examples_to_features
(
examples
,
label_list
,
...
...
@@ -1374,13 +1415,14 @@ def generate_tf_record_from_data_file(processor,
label_type
=
getattr
(
processor
,
"label_type"
,
None
)
is_regression
=
getattr
(
processor
,
"is_regression"
,
False
)
has_sample_weights
=
getattr
(
processor
,
"weight_key"
,
False
)
assert
train_data_output_path
train_input_data_examples
=
processor
.
get_train_examples
(
data_dir
)
file_based_convert_examples_to_features
(
train_input_data_examples
,
label_list
,
max_seq_length
,
tokenizer
,
train_data_output_path
,
label_type
)
num_training_data
=
len
(
train_input_data_examples
)
num_training_data
=
0
if
train_data_output_path
:
train_input_data_examples
=
processor
.
get_train_examples
(
data_dir
)
file_based_convert_examples_to_features
(
train_input_data_examples
,
label_list
,
max_seq_length
,
tokenizer
,
train_data_output_path
,
label_type
)
num_training_data
=
len
(
train_input_data_examples
)
if
eval_data_output_path
:
eval_input_data_examples
=
processor
.
get_dev_examples
(
data_dir
)
...
...
official/nlp/data/create_finetuning_data.py
View file @
431a9ca3
...
...
@@ -49,7 +49,7 @@ flags.DEFINE_string(
flags
.
DEFINE_enum
(
"classification_task_name"
,
"MNLI"
,
[
"AX"
,
"COLA"
,
"IMDB"
,
"MNLI"
,
"MRPC"
,
"PAWS-X"
,
"QNLI"
,
"QQP"
,
"RTE"
,
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
,
"AX-g"
],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
...
...
@@ -238,7 +238,10 @@ def generate_classifier_dataset():
functools
.
partial
(
classifier_data_lib
.
XtremePawsxProcessor
,
translated_data_dir
=
FLAGS
.
translated_input_data_dir
,
only_use_en_dev
=
FLAGS
.
only_use_en_dev
)
only_use_en_dev
=
FLAGS
.
only_use_en_dev
),
"ax-g"
:
classifier_data_lib
.
AXgProcessor
}
task_name
=
FLAGS
.
classification_task_name
.
lower
()
if
task_name
not
in
processors
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment