Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
06d67489
Commit
06d67489
authored
Jul 13, 2020
by
Kaushik Shivakumar
Browse files
fix anotehr file
parent
d745e052
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
419 additions
and
0 deletions
+419
-0
research/object_detection/dataset_tools/context_rcnn/generate_embedding_data.py
...ion/dataset_tools/context_rcnn/generate_embedding_data.py
+419
-0
No files found.
research/object_detection/dataset_tools/context_rcnn/generate_embedding_data.py
0 → 100644
View file @
06d67489
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""A Beam job to generate embedding data for camera trap images.
This tool runs inference with an exported Object Detection model in
`saved_model` format and produce raw embeddings for camera trap data. These
embeddings contain an object-centric feature embedding from Faster R-CNN, the
datetime that the image was taken (normalized in a specific way), and the
position of the object of interest. By default, only the highest-scoring object
embedding is included.
Steps to generate a embedding dataset:
1. Use object_detection/export_inference_graph.py to get a Faster R-CNN
`saved_model` for inference. The input node must accept a tf.Example proto.
2. Run this tool with `saved_model` from step 1 and an TFRecord of tf.Example
protos containing images for inference.
Example Usage:
--------------
python tensorflow_models/object_detection/export_inference_graph.py \
--alsologtostderr \
--input_type tf_example \
--pipeline_config_path path/to/faster_rcnn_model.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory \
--additional_output_tensor_names detection_features
python generate_embedding_data.py \
--alsologtostderr \
--embedding_input_tfrecord path/to/input_tfrecords* \
--embedding_output_tfrecord path/to/output_tfrecords \
--embedding_model_dir path/to/exported_model_directory/saved_model
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
datetime
import
os
import
threading
import
numpy
as
np
import
six
import
tensorflow.compat.v1
as
tf
try
:
import
apache_beam
as
beam
# pylint:disable=g-import-not-at-top
except
ModuleNotFoundError
:
pass
class
GenerateEmbeddingDataFn
(
beam
.
DoFn
):
"""Generates embedding data for camera trap images.
This Beam DoFn performs inference with an object detection `saved_model` and
produces contextual embedding vectors.
"""
session_lock
=
threading
.
Lock
()
def
__init__
(
self
,
model_dir
,
top_k_embedding_count
,
bottom_k_embedding_count
):
"""Initialization function.
Args:
model_dir: A directory containing saved model.
top_k_embedding_count: the number of high-confidence embeddings to store
bottom_k_embedding_count: the number of low-confidence embeddings to store
"""
self
.
_model_dir
=
model_dir
self
.
_session
=
None
self
.
_num_examples_processed
=
beam
.
metrics
.
Metrics
.
counter
(
'embedding_data_generation'
,
'num_tf_examples_processed'
)
self
.
_top_k_embedding_count
=
top_k_embedding_count
self
.
_bottom_k_embedding_count
=
bottom_k_embedding_count
def
start_bundle
(
self
):
self
.
_load_inference_model
()
def
_load_inference_model
(
self
):
# Because initialization of the tf.Session is expensive we share
# one instance across all threads in the worker. This is possible since
# tf.Session.run() is thread safe.
with
self
.
session_lock
:
if
self
.
_session
is
None
:
graph
=
tf
.
Graph
()
self
.
_session
=
tf
.
Session
(
graph
=
graph
)
with
graph
.
as_default
():
meta_graph
=
tf
.
saved_model
.
loader
.
load
(
self
.
_session
,
[
tf
.
saved_model
.
tag_constants
.
SERVING
],
self
.
_model_dir
)
signature
=
meta_graph
.
signature_def
[
'serving_default'
]
input_tensor_name
=
signature
.
inputs
[
'input_tensor'
].
name
detection_features_name
=
signature
.
outputs
[
'detection_features'
].
name
detection_boxes_name
=
signature
.
outputs
[
'detection_boxes'
].
name
num_detections_name
=
signature
.
outputs
[
'num_detections'
].
name
self
.
_input
=
graph
.
get_tensor_by_name
(
input_tensor_name
)
self
.
_embedding_node
=
graph
.
get_tensor_by_name
(
detection_features_name
)
self
.
_box_node
=
graph
.
get_tensor_by_name
(
detection_boxes_name
)
self
.
_scores_node
=
graph
.
get_tensor_by_name
(
signature
.
outputs
[
'detection_scores'
].
name
)
self
.
_num_detections
=
graph
.
get_tensor_by_name
(
num_detections_name
)
tf
.
logging
.
info
(
signature
.
outputs
[
'detection_features'
].
name
)
tf
.
logging
.
info
(
signature
.
outputs
[
'detection_boxes'
].
name
)
tf
.
logging
.
info
(
signature
.
outputs
[
'num_detections'
].
name
)
def
process
(
self
,
tfrecord_entry
):
return
self
.
_run_inference_and_generate_embedding
(
tfrecord_entry
)
def
_run_inference_and_generate_embedding
(
self
,
tfrecord_entry
):
input_example
=
tf
.
train
.
Example
.
FromString
(
tfrecord_entry
)
# Convert date_captured datetime string to unix time integer and store
def
get_date_captured
(
example
):
date_captured
=
datetime
.
datetime
.
strptime
(
six
.
ensure_str
(
example
.
features
.
feature
[
'image/date_captured'
].
bytes_list
.
value
[
0
]),
'%Y-%m-%d %H:%M:%S'
)
return
date_captured
try
:
date_captured
=
get_date_captured
(
input_example
)
except
Exception
:
# pylint: disable=broad-except
# we require date_captured to be available for all images
return
[]
def
embed_date_captured
(
date_captured
):
"""Encodes the datetime of the image."""
embedded_date_captured
=
[]
month_max
=
12.0
day_max
=
31.0
hour_max
=
24.0
minute_max
=
60.0
min_year
=
1990.0
max_year
=
2030.0
year
=
(
date_captured
.
year
-
min_year
)
/
float
(
max_year
-
min_year
)
embedded_date_captured
.
append
(
year
)
month
=
(
date_captured
.
month
-
1
)
/
month_max
embedded_date_captured
.
append
(
month
)
day
=
(
date_captured
.
day
-
1
)
/
day_max
embedded_date_captured
.
append
(
day
)
hour
=
date_captured
.
hour
/
hour_max
embedded_date_captured
.
append
(
hour
)
minute
=
date_captured
.
minute
/
minute_max
embedded_date_captured
.
append
(
minute
)
return
np
.
asarray
(
embedded_date_captured
)
def
embed_position_and_size
(
box
):
"""Encodes the bounding box of the object of interest."""
ymin
=
box
[
0
]
xmin
=
box
[
1
]
ymax
=
box
[
2
]
xmax
=
box
[
3
]
w
=
xmax
-
xmin
h
=
ymax
-
ymin
x
=
xmin
+
w
/
2.0
y
=
ymin
+
h
/
2.0
return
np
.
asarray
([
x
,
y
,
w
,
h
])
unix_time
=
(
(
date_captured
-
datetime
.
datetime
.
fromtimestamp
(
0
)).
total_seconds
())
example
=
tf
.
train
.
Example
()
example
.
features
.
feature
[
'image/unix_time'
].
float_list
.
value
.
extend
(
[
unix_time
])
(
detection_features
,
detection_boxes
,
num_detections
,
detection_scores
)
=
self
.
_session
.
run
(
[
self
.
_embedding_node
,
self
.
_box_node
,
self
.
_num_detections
[
0
],
self
.
_scores_node
],
feed_dict
=
{
self
.
_input
:
[
tfrecord_entry
]})
num_detections
=
int
(
num_detections
)
embed_all
=
[]
score_all
=
[]
detection_features
=
np
.
asarray
(
detection_features
)
def
get_bb_embedding
(
detection_features
,
detection_boxes
,
detection_scores
,
index
):
embedding
=
detection_features
[
0
][
index
]
pooled_embedding
=
np
.
mean
(
np
.
mean
(
embedding
,
axis
=
1
),
axis
=
0
)
box
=
detection_boxes
[
0
][
index
]
position_embedding
=
embed_position_and_size
(
box
)
score
=
detection_scores
[
0
][
index
]
return
np
.
concatenate
((
pooled_embedding
,
position_embedding
)),
score
temporal_embedding
=
embed_date_captured
(
date_captured
)
embedding_count
=
0
for
index
in
range
(
min
(
num_detections
,
self
.
_top_k_embedding_count
)):
bb_embedding
,
score
=
get_bb_embedding
(
detection_features
,
detection_boxes
,
detection_scores
,
index
)
embed_all
.
extend
(
bb_embedding
)
embed_all
.
extend
(
temporal_embedding
)
score_all
.
append
(
score
)
embedding_count
+=
1
for
index
in
range
(
max
(
0
,
num_detections
-
1
),
max
(
-
1
,
num_detections
-
1
-
self
.
_bottom_k_embedding_count
),
-
1
):
bb_embedding
,
score
=
get_bb_embedding
(
detection_features
,
detection_boxes
,
detection_scores
,
index
)
embed_all
.
extend
(
bb_embedding
)
embed_all
.
extend
(
temporal_embedding
)
score_all
.
append
(
score
)
embedding_count
+=
1
if
embedding_count
==
0
:
bb_embedding
,
score
=
get_bb_embedding
(
detection_features
,
detection_boxes
,
detection_scores
,
0
)
embed_all
.
extend
(
bb_embedding
)
embed_all
.
extend
(
temporal_embedding
)
score_all
.
append
(
score
)
# Takes max in case embedding_count is 0.
embedding_length
=
len
(
embed_all
)
//
max
(
1
,
embedding_count
)
embed_all
=
np
.
asarray
(
embed_all
)
example
.
features
.
feature
[
'image/embedding'
].
float_list
.
value
.
extend
(
embed_all
)
example
.
features
.
feature
[
'image/embedding_score'
].
float_list
.
value
.
extend
(
score_all
)
example
.
features
.
feature
[
'image/embedding_length'
].
int64_list
.
value
.
append
(
embedding_length
)
example
.
features
.
feature
[
'image/embedding_count'
].
int64_list
.
value
.
append
(
embedding_count
)
# Add other essential example attributes
example
.
features
.
feature
[
'image/encoded'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/encoded'
].
bytes_list
.
value
)
example
.
features
.
feature
[
'image/height'
].
int64_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/height'
].
int64_list
.
value
)
example
.
features
.
feature
[
'image/width'
].
int64_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/width'
].
int64_list
.
value
)
example
.
features
.
feature
[
'image/source_id'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/source_id'
].
bytes_list
.
value
)
example
.
features
.
feature
[
'image/location'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/location'
].
bytes_list
.
value
)
example
.
features
.
feature
[
'image/date_captured'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/date_captured'
].
bytes_list
.
value
)
example
.
features
.
feature
[
'image/class/text'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/class/text'
].
bytes_list
.
value
)
example
.
features
.
feature
[
'image/class/label'
].
int64_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/class/label'
].
int64_list
.
value
)
example
.
features
.
feature
[
'image/seq_id'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/seq_id'
].
bytes_list
.
value
)
example
.
features
.
feature
[
'image/seq_num_frames'
].
int64_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/seq_num_frames'
].
int64_list
.
value
)
example
.
features
.
feature
[
'image/seq_frame_num'
].
int64_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/seq_frame_num'
].
int64_list
.
value
)
example
.
features
.
feature
[
'image/object/bbox/ymax'
].
float_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/bbox/ymax'
].
float_list
.
value
)
example
.
features
.
feature
[
'image/object/bbox/ymin'
].
float_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/bbox/ymin'
].
float_list
.
value
)
example
.
features
.
feature
[
'image/object/bbox/xmax'
].
float_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/bbox/xmax'
].
float_list
.
value
)
example
.
features
.
feature
[
'image/object/bbox/xmin'
].
float_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/bbox/xmin'
].
float_list
.
value
)
example
.
features
.
feature
[
'image/object/class/score'
].
float_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/class/score'
].
float_list
.
value
)
example
.
features
.
feature
[
'image/object/class/label'
].
int64_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/class/label'
].
int64_list
.
value
)
example
.
features
.
feature
[
'image/object/class/text'
].
bytes_list
.
value
.
extend
(
input_example
.
features
.
feature
[
'image/object/class/text'
].
bytes_list
.
value
)
self
.
_num_examples_processed
.
inc
(
1
)
return
[
example
]
def
construct_pipeline
(
pipeline
,
input_tfrecord
,
output_tfrecord
,
model_dir
,
top_k_embedding_count
,
bottom_k_embedding_count
,
num_shards
):
"""Returns a beam pipeline to run object detection inference.
Args:
pipeline: Initialized beam pipeline.
input_tfrecord: An TFRecord of tf.train.Example protos containing images.
output_tfrecord: An TFRecord of tf.train.Example protos that contain images
in the input TFRecord and the detections from the model.
model_dir: Path to `saved_model` to use for inference.
top_k_embedding_count: The number of high-confidence embeddings to store.
bottom_k_embedding_count: The number of low-confidence embeddings to store.
num_shards: The number of output shards.
"""
input_collection
=
(
pipeline
|
'ReadInputTFRecord'
>>
beam
.
io
.
tfrecordio
.
ReadFromTFRecord
(
input_tfrecord
,
coder
=
beam
.
coders
.
BytesCoder
()))
output_collection
=
input_collection
|
'ExtractEmbedding'
>>
beam
.
ParDo
(
GenerateEmbeddingDataFn
(
model_dir
,
top_k_embedding_count
,
bottom_k_embedding_count
))
output_collection
=
output_collection
|
'Reshuffle'
>>
beam
.
Reshuffle
()
_
=
output_collection
|
'WritetoDisk'
>>
beam
.
io
.
tfrecordio
.
WriteToTFRecord
(
output_tfrecord
,
num_shards
=
num_shards
,
coder
=
beam
.
coders
.
ProtoCoder
(
tf
.
train
.
Example
))
def
parse_args
(
argv
):
"""Command-line argument parser.
Args:
argv: command line arguments
Returns:
beam_args: Arguments for the beam pipeline.
pipeline_args: Arguments for the pipeline options, such as runner type.
"""
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--embedding_input_tfrecord'
,
dest
=
'embedding_input_tfrecord'
,
required
=
True
,
help
=
'TFRecord containing images in tf.Example format for object '
'detection.'
)
parser
.
add_argument
(
'--embedding_output_tfrecord'
,
dest
=
'embedding_output_tfrecord'
,
required
=
True
,
help
=
'TFRecord containing embeddings in tf.Example format.'
)
parser
.
add_argument
(
'--embedding_model_dir'
,
dest
=
'embedding_model_dir'
,
required
=
True
,
help
=
'Path to directory containing an object detection SavedModel with'
'detection_box_classifier_features in the output.'
)
parser
.
add_argument
(
'--top_k_embedding_count'
,
dest
=
'top_k_embedding_count'
,
default
=
1
,
help
=
'The number of top k embeddings to add to the memory bank.'
)
parser
.
add_argument
(
'--bottom_k_embedding_count'
,
dest
=
'bottom_k_embedding_count'
,
default
=
0
,
help
=
'The number of bottom k embeddings to add to the memory bank.'
)
parser
.
add_argument
(
'--num_shards'
,
dest
=
'num_shards'
,
default
=
0
,
help
=
'Number of output shards.'
)
beam_args
,
pipeline_args
=
parser
.
parse_known_args
(
argv
)
return
beam_args
,
pipeline_args
def
main
(
argv
=
None
,
save_main_session
=
True
):
"""Runs the Beam pipeline that performs inference.
Args:
argv: Command line arguments.
save_main_session: Whether to save the main session.
"""
args
,
pipeline_args
=
parse_args
(
argv
)
pipeline_options
=
beam
.
options
.
pipeline_options
.
PipelineOptions
(
pipeline_args
)
pipeline_options
.
view_as
(
beam
.
options
.
pipeline_options
.
SetupOptions
).
save_main_session
=
(
save_main_session
)
dirname
=
os
.
path
.
dirname
(
args
.
embedding_output_tfrecord
)
tf
.
io
.
gfile
.
makedirs
(
dirname
)
p
=
beam
.
Pipeline
(
options
=
pipeline_options
)
construct_pipeline
(
p
,
args
.
embedding_input_tfrecord
,
args
.
embedding_output_tfrecord
,
args
.
embedding_model_dir
,
args
.
top_k_embedding_count
,
args
.
bottom_k_embedding_count
,
args
.
num_shards
)
p
.
run
()
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment