Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6c522f8b
Commit
6c522f8b
authored
Nov 17, 2017
by
Vivek Rathod
Browse files
update offline metrics for Open Images dataset.
parent
11e9c7ad
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
627 additions
and
0 deletions
+627
-0
research/object_detection/metrics/BUILD
research/object_detection/metrics/BUILD
+55
-0
research/object_detection/metrics/offline_eval_map_corloc.py
research/object_detection/metrics/offline_eval_map_corloc.py
+173
-0
research/object_detection/metrics/offline_eval_map_corloc_test.py
.../object_detection/metrics/offline_eval_map_corloc_test.py
+58
-0
research/object_detection/metrics/tf_example_parser.py
research/object_detection/metrics/tf_example_parser.py
+155
-0
research/object_detection/metrics/tf_example_parser_test.py
research/object_detection/metrics/tf_example_parser_test.py
+186
-0
No files found.
research/object_detection/metrics/BUILD
0 → 100644
View file @
6c522f8b
# Tensorflow Object Detection API: main runnables.

package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])  # Apache 2.0

# Offline evaluation binary: computes mAP / CorLoc metrics from
# pre-computed detections stored in tf.Example records.
py_binary(
    name = "offline_eval_map_corloc",
    srcs = [
        "offline_eval_map_corloc.py",
    ],
    deps = [
        ":tf_example_parser",
        "//tensorflow_models/object_detection:evaluator",
        "//tensorflow_models/object_detection/builders:input_reader_builder",
        "//tensorflow_models/object_detection/core:standard_fields",
        "//tensorflow_models/object_detection/utils:config_util",
        "//tensorflow_models/object_detection/utils:label_map_util",
    ],
)

# Unit tests for the filename-expansion helpers of the offline binary.
py_test(
    name = "offline_eval_map_corloc_test",
    srcs = [
        "offline_eval_map_corloc_test.py",
    ],
    deps = [
        ":offline_eval_map_corloc",
        "//tensorflow",
    ],
)

# Parsers that materialize tf.Example fields into numpy arrays.
py_library(
    name = "tf_example_parser",
    srcs = ["tf_example_parser.py"],
    deps = [
        "//tensorflow",
        "//tensorflow_models/object_detection/core:data_parser",
        "//tensorflow_models/object_detection/core:standard_fields",
    ],
)

# Unit tests for the tf.Example parsers.
py_test(
    name = "tf_example_parser_test",
    srcs = ["tf_example_parser_test.py"],
    deps = [
        ":tf_example_parser",
        "//tensorflow",
        "//tensorflow_models/object_detection/core:standard_fields",
    ],
)
research/object_detection/metrics/offline_eval_map_corloc.py
0 → 100644
View file @
6c522f8b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Evaluation executable for detection data.
This executable evaluates precomputed detections produced by a detection
model and writes the evaluation results into csv file metrics.csv, stored
in the directory, specified by --eval_dir.
The evaluation metrics set is supplied in object_detection.protos.EvalConfig
in metrics_set field.
Currently two sets of metrics are supported:
- pascal_voc_metrics: standard PASCAL VOC 2007 metric
- open_images_metrics: Open Image V2 metric
All other fields of object_detection.protos.EvalConfig are ignored.
Example usage:
./compute_metrics \
--eval_dir=path/to/eval_dir \
--eval_config_path=path/to/evaluation/configuration/file \
--input_config_path=path/to/input/configuration/file
"""
import
csv
import
os
import
re
import
tensorflow
as
tf
from
object_detection
import
evaluator
from
object_detection.core
import
standard_fields
from
object_detection.metrics
import
tf_example_parser
from
object_detection.utils
import
config_util
from
object_detection.utils
import
label_map_util
# Command-line flags for the offline evaluation binary.
flags = tf.app.flags
tf.logging.set_verbosity(tf.logging.INFO)

# All three flags are required; main() validates them before use.
flags.DEFINE_string('eval_dir', None, 'Directory to write eval summaries to.')
flags.DEFINE_string('eval_config_path', None,
                    'Path to an eval_pb2.EvalConfig config file.')
flags.DEFINE_string('input_config_path', None,
                    'Path to an eval_pb2.InputConfig config file.')

FLAGS = flags.FLAGS
def
_generate_sharded_filenames
(
filename
):
m
=
re
.
search
(
r
'@(\d{1,})'
,
filename
)
if
m
:
num_shards
=
int
(
m
.
group
(
1
))
return
[
re
.
sub
(
r
'@(\d{1,})'
,
'-%.5d-of-%.5d'
%
(
i
,
num_shards
),
filename
)
for
i
in
range
(
num_shards
)
]
else
:
return
[
filename
]
def _generate_filenames(filenames):
  """Expands each (possibly sharded) filename and concatenates the results.

  Args:
    filenames: an iterable of filenames, each possibly containing an '@N'
      shard marker understood by _generate_sharded_filenames.

  Returns:
    A flat list of concrete filenames, in input order.
  """
  expanded = []
  for name in filenames:
    expanded.extend(_generate_sharded_filenames(name))
  return expanded
def read_data_and_evaluate(input_config, eval_config):
  """Reads pre-computed object detections and groundtruth from tf_record.

  Args:
    input_config: input config proto of type
      object_detection.protos.InputReader.
    eval_config: evaluation config proto of type
      object_detection.protos.EvalConfig.

  Returns:
    Evaluated detections metrics.

  Raises:
    ValueError: if input_reader type is not supported or metric type is unknown.
  """
  if input_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    input_paths = input_config.tf_record_input_reader.input_path

    label_map = label_map_util.load_labelmap(input_config.label_map_path)
    # Label-map ids may be sparse; the largest id bounds the class count.
    max_num_classes = max([item.id for item in label_map.item])
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes)

    object_detection_evaluators = evaluator.get_evaluators(
        eval_config, categories)
    # Support a single evaluator
    object_detection_evaluator = object_detection_evaluators[0]

    skipped_images = 0
    processed_images = 0
    # Expand any sharded input patterns (e.g. 'name@N') into concrete files.
    for input_path in _generate_filenames(input_paths):
      tf.logging.info('Processing file: {0}'.format(input_path))

      record_iterator = tf.python_io.tf_record_iterator(path=input_path)
      data_parser = tf_example_parser.TfExampleDetectionAndGTParser()

      for string_record in record_iterator:
        # Progress log at most once per 1000 records.
        tf.logging.log_every_n(tf.logging.INFO, 'Processed %d images...', 1000,
                               processed_images)
        processed_images += 1

        example = tf.train.Example()
        example.ParseFromString(string_record)
        # decoded_dict is None when a required field is missing (see
        # TfExampleDetectionAndGTParser.parse).
        decoded_dict = data_parser.parse(example)

        if decoded_dict:
          # Each record carries both groundtruth and detections for one
          # image, keyed by its source id.
          object_detection_evaluator.add_single_ground_truth_image_info(
              decoded_dict[standard_fields.DetectionResultFields.key],
              decoded_dict)
          object_detection_evaluator.add_single_detected_image_info(
              decoded_dict[standard_fields.DetectionResultFields.key],
              decoded_dict)
        else:
          # Unparseable records are counted and skipped, not fatal.
          skipped_images += 1
          tf.logging.info('Skipped images: {0}'.format(skipped_images))

    return object_detection_evaluator.evaluate()

  raise ValueError('Unsupported input_reader_config.')
def write_metrics(metrics, output_dir):
  """Write metrics to the output directory.

  Args:
    metrics: A dictionary containing metric names and values.
    output_dir: Directory to write metrics to.
  """
  tf.logging.info('Writing metrics.')

  # One "name,value" row per metric in metrics.csv under output_dir.
  csv_path = os.path.join(output_dir, 'metrics.csv')
  with open(csv_path, 'w') as out_file:
    writer = csv.writer(out_file, delimiter=',')
    for name, value in metrics.items():
      writer.writerow([name, str(value)])
def main(argv):
  """Validates flags, runs the offline evaluation and saves the metrics."""
  del argv  # Unused.
  # All three flags must be set explicitly; DEFINE_string defaults are None.
  required_flags = ['input_config_path', 'eval_config_path', 'eval_dir']
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('Flag --{} is required'.format(flag_name))

  configs = config_util.get_configs_from_multiple_files(
      eval_input_config_path=FLAGS.input_config_path,
      eval_config_path=FLAGS.eval_config_path)

  eval_config = configs['eval_config']
  input_config = configs['eval_input_config']

  metrics = read_data_and_evaluate(input_config, eval_config)

  # Save metrics
  write_metrics(metrics, FLAGS.eval_dir)
# Script entry point: tf.app.run parses flags and invokes main().
if __name__ == '__main__':
  tf.app.run(main)
research/object_detection/metrics/offline_eval_map_corloc_test.py
0 → 100644
View file @
6c522f8b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities in offline_eval_map_corloc binary."""
import
tensorflow
as
tf
from
object_detection.metrics
import
offline_eval_map_corloc
as
offline_eval
class OfflineEvalMapCorlocTest(tf.test.TestCase):
  """Tests the sharded-filename expansion helpers of offline_eval."""

  def test_generateShardedFilenames(self):
    """Checks expansion of single filenames with and without '@N' markers."""
    # No shard marker: returned unchanged.
    test_filename = '/path/to/file'
    result = offline_eval._generate_sharded_filenames(test_filename)
    self.assertEqual(result, [test_filename])

    # Already-expanded shard names contain no '@N' and pass through.
    test_filename = '/path/to/file-00000-of-00050'
    result = offline_eval._generate_sharded_filenames(test_filename)
    self.assertEqual(result, [test_filename])

    # '@3' expands to three '-0000i-of-00003' names, suffix preserved.
    result = offline_eval._generate_sharded_filenames('/path/to/@3.sst')
    self.assertEqual(result, [
        '/path/to/-00000-of-00003.sst', '/path/to/-00001-of-00003.sst',
        '/path/to/-00002-of-00003.sst'
    ])

    result = offline_eval._generate_sharded_filenames('/path/to/abc@3')
    self.assertEqual(result, [
        '/path/to/abc-00000-of-00003', '/path/to/abc-00001-of-00003',
        '/path/to/abc-00002-of-00003'
    ])

    # Degenerate single-shard pattern.
    result = offline_eval._generate_sharded_filenames('/path/to/@1')
    self.assertEqual(result, ['/path/to/-00000-of-00001'])

  def test_generateFilenames(self):
    """Checks that a list of patterns expands to one flat, ordered list."""
    test_filenames = ['/path/to/file', '/path/to/@3.sst']
    result = offline_eval._generate_filenames(test_filenames)
    self.assertEqual(result, [
        '/path/to/file', '/path/to/-00000-of-00003.sst',
        '/path/to/-00001-of-00003.sst', '/path/to/-00002-of-00003.sst'
    ])
# Test entry point.
if __name__ == '__main__':
  tf.test.main()
research/object_detection/metrics/tf_example_parser.py
0 → 100644
View file @
6c522f8b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow Example proto parser for data loading.
A parser to decode data containing serialized tensorflow.Example
protos into materialized tensors (numpy arrays).
"""
import
numpy
as
np
from
object_detection.core
import
data_parser
from
object_detection.core
import
standard_fields
as
fields
class FloatParser(data_parser.DataToNumpyParser):
  """Tensorflow Example float parser.

  Materializes a float_list feature of a tf.Example as a numpy array.
  """

  def __init__(self, field_name):
    """Initializes the parser.

    Args:
      field_name: name of the float_list feature to extract.
    """
    self.field_name = field_name

  def parse(self, tf_example):
    """Parses the configured float field from a tf.Example proto.

    Args:
      tf_example: a tf.train.Example proto.

    Returns:
      A numpy array with the feature's float values, or None if the field
      has no float_list populated.
    """
    feature = tf_example.features.feature[self.field_name]
    if not feature.HasField("float_list"):
      return None
    # Use the builtin `float` (i.e. float64) as dtype: the `np.float`
    # alias was deprecated in NumPy 1.20 and removed in 1.24, and was
    # always identical to the builtin float.
    return np.array(feature.float_list.value, dtype=float).transpose()
class StringParser(data_parser.DataToNumpyParser):
  """Tensorflow Example string parser.

  Materializes a bytes_list feature of a tf.Example as a single string.
  """

  def __init__(self, field_name):
    """Initializes the parser.

    Args:
      field_name: name of the bytes_list feature to extract.
    """
    self.field_name = field_name

  def parse(self, tf_example):
    """Returns the field's bytes joined into one string, or None if absent.

    Args:
      tf_example: a tf.train.Example proto.
    """
    feature = tf_example.features.feature[self.field_name]
    if feature.HasField("bytes_list"):
      return "".join(feature.bytes_list.value)
    return None
class Int64Parser(data_parser.DataToNumpyParser):
  """Tensorflow Example int64 parser.

  Materializes an int64_list feature of a tf.Example as a numpy array.
  """

  def __init__(self, field_name):
    """Initializes the parser.

    Args:
      field_name: name of the int64_list feature to extract.
    """
    self.field_name = field_name

  def parse(self, tf_example):
    """Returns the field values as an int64 numpy array, or None if absent.

    Args:
      tf_example: a tf.train.Example proto.
    """
    feature = tf_example.features.feature[self.field_name]
    if not feature.HasField("int64_list"):
      return None
    return np.array(feature.int64_list.value, dtype=np.int64).transpose()
class BoundingBoxParser(data_parser.DataToNumpyParser):
  """Tensorflow Example bounding box parser.

  Assembles four per-coordinate float_list features into one box array.
  """

  def __init__(self, xmin_field_name, ymin_field_name, xmax_field_name,
               ymax_field_name):
    """Initializes the parser.

    Args:
      xmin_field_name: feature name holding box xmin coordinates.
      ymin_field_name: feature name holding box ymin coordinates.
      xmax_field_name: feature name holding box xmax coordinates.
      ymax_field_name: feature name holding box ymax coordinates.
    """
    # Stored in [ymin, xmin, ymax, xmax] column order.
    self.field_names = [
        ymin_field_name, xmin_field_name, ymax_field_name, xmax_field_name
    ]

  def parse(self, tf_example):
    """Returns boxes as an array with columns [ymin, xmin, ymax, xmax].

    Args:
      tf_example: a tf.train.Example proto.

    Returns:
      A numpy array of shape (num_boxes, 4), or None if any coordinate
      field has no float_list populated.
    """
    columns = []
    all_present = True
    for name in self.field_names:
      feature = tf_example.features.feature[name]
      columns.append(feature.float_list.value)
      all_present &= feature.HasField("float_list")
    if not all_present:
      return None
    return np.array(columns).transpose()
class TfExampleDetectionAndGTParser(data_parser.DataToNumpyParser):
  """Tensorflow Example proto parser."""

  def __init__(self):
    # Required fields: the parse() result is None if any of these is
    # missing from the example.
    self.items_to_handlers = {
        fields.DetectionResultFields.key:
            StringParser(fields.TfExampleFields.source_id),
        # Object ground truth boxes and classes.
        fields.InputDataFields.groundtruth_boxes: (BoundingBoxParser(
            fields.TfExampleFields.object_bbox_xmin,
            fields.TfExampleFields.object_bbox_ymin,
            fields.TfExampleFields.object_bbox_xmax,
            fields.TfExampleFields.object_bbox_ymax)),
        fields.InputDataFields.groundtruth_classes: (
            Int64Parser(fields.TfExampleFields.object_class_label)),
        # Object detections.
        fields.DetectionResultFields.detection_boxes: (BoundingBoxParser(
            fields.TfExampleFields.detection_bbox_xmin,
            fields.TfExampleFields.detection_bbox_ymin,
            fields.TfExampleFields.detection_bbox_xmax,
            fields.TfExampleFields.detection_bbox_ymax)),
        fields.DetectionResultFields.detection_classes: (
            Int64Parser(fields.TfExampleFields.detection_class_label)),
        fields.DetectionResultFields.detection_scores: (
            FloatParser(fields.TfExampleFields.detection_score)),
    }

    # Optional fields: their values may be None without invalidating
    # the parse() result.
    self.optional_items_to_handlers = {
        fields.InputDataFields.groundtruth_difficult:
            Int64Parser(fields.TfExampleFields.object_difficult),
        fields.InputDataFields.groundtruth_group_of:
            Int64Parser(fields.TfExampleFields.object_group_of)
    }

  def parse(self, tf_example):
    """Parses tensorflow example and returns a tensor dictionary.

    Args:
      tf_example: a tf.Example object.

    Returns:
      A dictionary of the following numpy arrays:
      fields.DetectionResultFields.key - string containing original image
      id.
      fields.InputDataFields.groundtruth_boxes - a numpy array containing
      groundtruth boxes.
      fields.InputDataFields.groundtruth_classes - a numpy array containing
      groundtruth classes.
      fields.InputDataFields.groundtruth_group_of - a numpy array containing
      groundtruth group of flag (optional, None if not specified).
      fields.InputDataFields.groundtruth_difficult - a numpy array containing
      groundtruth difficult flag (optional, None if not specified).
      fields.DetectionResultFields.detection_boxes - a numpy array containing
      detection boxes.
      fields.DetectionResultFields.detection_classes - a numpy array containing
      detection class labels.
      fields.DetectionResultFields.detection_scores - a numpy array containing
      detection scores.
      Returns None if tf.Example was not parsed or non-optional fields were not
      found.
    """
    results_dict = {}
    parsed = True
    # Every required field must parse to a non-None value.
    for key, parser in self.items_to_handlers.items():
      results_dict[key] = parser.parse(tf_example)
      parsed &= (results_dict[key] is not None)

    # Optional fields are stored even when None.
    for key, parser in self.optional_items_to_handlers.items():
      results_dict[key] = parser.parse(tf_example)

    return results_dict if parsed else None
research/object_detection/metrics/tf_example_parser_test.py
0 → 100644
View file @
6c522f8b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.data_decoders.tf_example_parser."""
import
numpy
as
np
import
numpy.testing
as
np_testing
import
tensorflow
as
tf
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.metrics
import
tf_example_parser
class TfExampleDecoderTest(tf.test.TestCase):
  """Tests for the tf.Example field parsers in tf_example_parser."""

  def _Int64Feature(self, value):
    """Wraps a list of ints in a tf.train.Feature int64_list."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

  def _FloatFeature(self, value):
    """Wraps a list of floats in a tf.train.Feature float_list."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

  def _BytesFeature(self, value):
    """Wraps a single string in a tf.train.Feature bytes_list."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  def testParseDetectionsAndGT(self):
    """Exercises TfExampleDetectionAndGTParser on a full example."""
    source_id = 'abc.jpg'
    # y_min, x_min, y_max, x_max
    object_bb = np.array([[0.0, 0.5, 0.3], [0.0, 0.1, 0.6], [1.0, 0.6, 0.8],
                          [1.0, 0.6, 0.7]]).transpose()
    detection_bb = np.array([[0.1, 0.2], [0.0, 0.8], [1.0, 0.6],
                             [1.0, 0.85]]).transpose()

    object_class_label = [1, 1, 2]
    object_difficult = [1, 0, 0]
    object_group_of = [0, 0, 1]
    detection_class_label = [2, 1]
    detection_score = [0.5, 0.3]
    # Deliberately omits object_class_label so the first parse fails.
    features = {
        fields.TfExampleFields.source_id:
            self._BytesFeature(source_id),
        fields.TfExampleFields.object_bbox_ymin:
            self._FloatFeature(object_bb[:, 0].tolist()),
        fields.TfExampleFields.object_bbox_xmin:
            self._FloatFeature(object_bb[:, 1].tolist()),
        fields.TfExampleFields.object_bbox_ymax:
            self._FloatFeature(object_bb[:, 2].tolist()),
        fields.TfExampleFields.object_bbox_xmax:
            self._FloatFeature(object_bb[:, 3].tolist()),
        fields.TfExampleFields.detection_bbox_ymin:
            self._FloatFeature(detection_bb[:, 0].tolist()),
        fields.TfExampleFields.detection_bbox_xmin:
            self._FloatFeature(detection_bb[:, 1].tolist()),
        fields.TfExampleFields.detection_bbox_ymax:
            self._FloatFeature(detection_bb[:, 2].tolist()),
        fields.TfExampleFields.detection_bbox_xmax:
            self._FloatFeature(detection_bb[:, 3].tolist()),
        fields.TfExampleFields.detection_class_label:
            self._Int64Feature(detection_class_label),
        fields.TfExampleFields.detection_score:
            self._FloatFeature(detection_score),
    }

    example = tf.train.Example(features=tf.train.Features(feature=features))
    parser = tf_example_parser.TfExampleDetectionAndGTParser()

    # Missing required groundtruth classes -> parse returns None.
    results_dict = parser.parse(example)
    self.assertIsNone(results_dict)

    # Add the missing required field plus the optional difficult flag.
    features[fields.TfExampleFields.object_class_label] = (
        self._Int64Feature(object_class_label))
    features[fields.TfExampleFields.object_difficult] = (
        self._Int64Feature(object_difficult))

    example = tf.train.Example(features=tf.train.Features(feature=features))
    results_dict = parser.parse(example)

    self.assertIsNotNone(results_dict)
    self.assertEqual(source_id, results_dict[fields.DetectionResultFields.key])
    np_testing.assert_almost_equal(
        object_bb, results_dict[fields.InputDataFields.groundtruth_boxes])
    np_testing.assert_almost_equal(
        detection_bb,
        results_dict[fields.DetectionResultFields.detection_boxes])
    np_testing.assert_almost_equal(
        detection_score,
        results_dict[fields.DetectionResultFields.detection_scores])
    np_testing.assert_almost_equal(
        detection_class_label,
        results_dict[fields.DetectionResultFields.detection_classes])
    np_testing.assert_almost_equal(
        object_difficult,
        results_dict[fields.InputDataFields.groundtruth_difficult])
    np_testing.assert_almost_equal(
        object_class_label,
        results_dict[fields.InputDataFields.groundtruth_classes])

    parser = tf_example_parser.TfExampleDetectionAndGTParser()

    # The optional group_of flag is surfaced when present.
    features[fields.TfExampleFields.object_group_of] = (
        self._Int64Feature(object_group_of))

    example = tf.train.Example(features=tf.train.Features(feature=features))
    results_dict = parser.parse(example)
    self.assertIsNotNone(results_dict)
    np_testing.assert_almost_equal(
        object_group_of,
        results_dict[fields.InputDataFields.groundtruth_group_of])

  def testParseString(self):
    """StringParser returns the field value, or None for a missing field."""
    string_val = 'abc'
    features = {'string': self._BytesFeature(string_val)}
    example = tf.train.Example(features=tf.train.Features(feature=features))

    parser = tf_example_parser.StringParser('string')
    result = parser.parse(example)
    self.assertIsNotNone(result)
    self.assertEqual(result, string_val)

    parser = tf_example_parser.StringParser('another_string')
    result = parser.parse(example)
    self.assertIsNone(result)

  def testParseFloat(self):
    """FloatParser returns the float values, or None for a missing field."""
    float_array_val = [1.5, 1.4, 2.0]
    features = {'floats': self._FloatFeature(float_array_val)}
    example = tf.train.Example(features=tf.train.Features(feature=features))

    parser = tf_example_parser.FloatParser('floats')
    result = parser.parse(example)
    self.assertIsNotNone(result)
    np_testing.assert_almost_equal(result, float_array_val)

    parser = tf_example_parser.StringParser('another_floats')
    result = parser.parse(example)
    self.assertIsNone(result)

  def testInt64Parser(self):
    """Int64Parser returns the int values, or None for a missing field."""
    int_val = [1, 2, 3]
    features = {'ints': self._Int64Feature(int_val)}
    example = tf.train.Example(features=tf.train.Features(feature=features))

    parser = tf_example_parser.Int64Parser('ints')
    result = parser.parse(example)
    self.assertIsNotNone(result)
    np_testing.assert_almost_equal(result, int_val)

    parser = tf_example_parser.Int64Parser('another_ints')
    result = parser.parse(example)
    self.assertIsNone(result)

  def testBoundingBoxParser(self):
    """BoundingBoxParser assembles boxes, or None if a coordinate is missing."""
    bounding_boxes = np.array([[0.0, 0.5, 0.3], [0.0, 0.1, 0.6],
                               [1.0, 0.6, 0.8], [1.0, 0.6, 0.7]]).transpose()
    features = {
        'ymin': self._FloatFeature(bounding_boxes[:, 0]),
        'xmin': self._FloatFeature(bounding_boxes[:, 1]),
        'ymax': self._FloatFeature(bounding_boxes[:, 2]),
        'xmax': self._FloatFeature(bounding_boxes[:, 3])
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))

    parser = tf_example_parser.BoundingBoxParser('xmin', 'ymin', 'xmax', 'ymax')
    result = parser.parse(example)
    self.assertIsNotNone(result)
    np_testing.assert_almost_equal(result, bounding_boxes)

    parser = tf_example_parser.BoundingBoxParser('xmin', 'ymin', 'xmax',
                                                 'another_ymax')
    result = parser.parse(example)
    self.assertIsNone(result)
# Test entry point.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment