Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ad423d06
Commit
ad423d06
authored
Jun 11, 2020
by
Maxim Neumann
Committed by
A. Unique TensorFlower
Jun 11, 2020
Browse files
Internal change
PiperOrigin-RevId: 315855426
parent
0b23ad50
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
5 deletions
+22
-5
official/nlp/data/classifier_data_lib.py
official/nlp/data/classifier_data_lib.py
+22
-5
No files found.
official/nlp/data/classifier_data_lib.py
View file @
ad423d06
...
...
@@ -33,7 +33,7 @@ from official.nlp.bert import tokenization
class
InputExample
(
object
):
"""A single training/test example for simple sequence classification."""
def
__init__
(
self
,
guid
,
text_a
,
text_b
=
None
,
label
=
None
):
def
__init__
(
self
,
guid
,
text_a
,
text_b
=
None
,
label
=
None
,
weight
=
None
):
"""Constructs a InputExample.
Args:
...
...
@@ -44,11 +44,14 @@ class InputExample(object):
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during
training.
"""
self
.
guid
=
guid
self
.
text_a
=
text_a
self
.
text_b
=
text_b
self
.
label
=
label
self
.
weight
=
weight
class
InputFeatures
(
object
):
...
...
@@ -59,12 +62,14 @@ class InputFeatures(object):
input_mask
,
segment_ids
,
label_id
,
is_real_example
=
True
):
is_real_example
=
True
,
weight
=
None
):
self
.
input_ids
=
input_ids
self
.
input_mask
=
input_mask
self
.
segment_ids
=
segment_ids
self
.
label_id
=
label_id
self
.
is_real_example
=
is_real_example
self
.
weight
=
weight
class
DataProcessor
(
object
):
...
...
@@ -574,6 +579,7 @@ class TfdsProcessor(DataProcessor):
test_text_b_key: Key of the second text feature to use in test set.
test_label: String to be used as the label for all test examples.
label_type: Type of the label key (defaults to `int`).
weight_key: Key of the float sample weight (is not used if not provided).
is_regression: Whether the task is a regression problem (defaults to False).
"""
...
...
@@ -612,6 +618,7 @@ class TfdsProcessor(DataProcessor):
self
.
test_label
=
d
.
get
(
"test_label"
,
"test_example"
)
self
.
label_type
=
dtype_map
[
d
.
get
(
"label_type"
,
"int"
)]
self
.
is_regression
=
cast_str_to_bool
(
d
.
get
(
"is_regression"
,
"False"
))
self
.
weight_key
=
d
.
get
(
"weight_key"
,
None
)
def
get_train_examples
(
self
,
data_dir
):
assert
data_dir
is
None
...
...
@@ -637,7 +644,7 @@ class TfdsProcessor(DataProcessor):
raise
ValueError
(
"Split {} not available."
.
format
(
split_name
))
dataset
=
self
.
dataset
[
split_name
].
as_numpy_iterator
()
examples
=
[]
text_b
=
None
text_b
,
weight
=
None
,
None
for
i
,
example
in
enumerate
(
dataset
):
guid
=
"%s-%s"
%
(
set_type
,
i
)
if
set_type
==
"test"
:
...
...
@@ -650,8 +657,11 @@ class TfdsProcessor(DataProcessor):
if
self
.
text_b_key
:
text_b
=
self
.
process_text_fn
(
example
[
self
.
text_b_key
])
label
=
self
.
label_type
(
example
[
self
.
label_key
])
if
self
.
weight_key
:
weight
=
float
(
example
[
self
.
weight_key
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
,
weight
=
weight
))
return
examples
...
...
@@ -739,13 +749,15 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
logging
.
info
(
"input_mask: %s"
,
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
logging
.
info
(
"segment_ids: %s"
,
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
logging
.
info
(
"label: %s (id = %d)"
,
example
.
label
,
label_id
)
logging
.
info
(
"weight: %s"
,
example
.
weight
)
feature
=
InputFeatures
(
input_ids
=
input_ids
,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
label_id
=
label_id
,
is_real_example
=
True
)
is_real_example
=
True
,
weight
=
example
.
weight
)
return
feature
...
...
@@ -781,6 +793,8 @@ def file_based_convert_examples_to_features(examples, label_list,
features
[
"label_ids"
]
=
create_int_feature
([
feature
.
label_id
])
features
[
"is_real_example"
]
=
create_int_feature
(
[
int
(
feature
.
is_real_example
)])
if
feature
.
weight
is
not
None
:
features
[
"weight"
]
=
create_float_feature
([
feature
.
weight
])
tf_example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
features
))
writer
.
write
(
tf_example
.
SerializeToString
())
...
...
@@ -837,6 +851,7 @@ def generate_tf_record_from_data_file(processor,
label_list
=
processor
.
get_labels
()
label_type
=
getattr
(
processor
,
"label_type"
,
None
)
is_regression
=
getattr
(
processor
,
"is_regression"
,
False
)
has_sample_weights
=
getattr
(
processor
,
"weight_key"
,
False
)
assert
train_data_output_path
train_input_data_examples
=
processor
.
get_train_examples
(
data_dir
)
...
...
@@ -879,6 +894,8 @@ def generate_tf_record_from_data_file(processor,
else
:
meta_data
[
"task_type"
]
=
"bert_classification"
meta_data
[
"num_labels"
]
=
len
(
processor
.
get_labels
())
if
has_sample_weights
:
meta_data
[
"has_sample_weights"
]
=
True
if
eval_data_output_path
:
meta_data
[
"eval_data_size"
]
=
len
(
eval_input_data_examples
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment