Commit 0eabf192 authored by Kaushik Shivakumar

ughh

parent 5a2cf36f
@@ -133,7 +133,6 @@ class Ava(object):
       hop_between_sequences: The gap between the centers of
         successive sequences.
     """
-    global_source_id = 0
     logging.info("Downloading data.")
     download_output = self._download_data()
     for key in splits_to_process.split(","):
@@ -141,7 +140,7 @@ class Ava(object):
       all_metadata = list(self._generate_examples(
           download_output[0][key][0], download_output[0][key][1],
           download_output[1], seconds_per_sequence, hop_between_sequences,
-          video_path_format_string, global_source_id))
+          video_path_format_string))
       logging.info("An example of the metadata: ")
       logging.info(all_metadata[0])
       random.seed(47)
@@ -177,7 +176,6 @@ class Ava(object):
    Yields:
      Each prepared tf.SequenceExample of metadata also containing video frames
    """
-    global GLOBAL_SOURCE_ID
    fieldnames = ["id", "timestamp_seconds", "xmin", "ymin", "xmax", "ymax",
                  "action_label"]
    frame_excluded = {}
@@ -199,6 +197,8 @@ class Ava(object):
    logging.info("Generating metadata...")
    media_num = 1
    for media_id in ids:
+      if media_num > 2:
+        continue
      logging.info("%d/%d, ignore warnings.\n" % (media_num, len(ids)))
      media_num += 1
@@ -213,7 +213,6 @@ class Ava(object):
                           0 if seconds_per_sequence % 2 == 0 else 1)
      end_time = middle_frame_time + (seconds_per_sequence // 2)
-      GLOBAL_SOURCE_ID += 1
      total_xmins = []
      total_xmaxs = []
      total_ymins = []
@@ -239,12 +238,10 @@ class Ava(object):
          _, buffer = cv2.imencode('.jpg', image)
          bufstring = buffer.tostring()
-          total_images.append(dataset_util.bytes_feature(bufstring))
-          source_id = str(GLOBAL_SOURCE_ID) + "_" + media_id
-          total_source_ids.append(dataset_util.bytes_feature(
-              source_id.encode("utf8")))
-          total_is_annotated.append(dataset_util.int64_feature(1))
-          GLOBAL_SOURCE_ID += 1
+          total_images.append(bufstring)
+          source_id = str(windowed_timestamp) + "_" + media_id
+          total_source_ids.append(source_id)
+          total_is_annotated.append(1)
          xmins = []
          xmaxs = []
@@ -265,54 +262,19 @@ class Ava(object):
            else:
              logging.warning("Unknown label: %s", row["action_label"])
-          total_xmins.append(dataset_util.float_list_feature(xmins))
-          total_xmaxs.append(dataset_util.float_list_feature(xmaxs))
-          total_ymins.append(dataset_util.float_list_feature(ymins))
-          total_ymaxs.append(dataset_util.float_list_feature(ymaxs))
-          total_labels.append(dataset_util.int64_list_feature(labels))
-          total_label_strings.append(
-              dataset_util.bytes_list_feature(label_strings))
-          total_confidences.append(
-              dataset_util.float_list_feature(confidences))
+          total_xmins.append(xmins)
+          total_xmaxs.append(xmaxs)
+          total_ymins.append(ymins)
+          total_ymaxs.append(ymaxs)
+          total_labels.append(labels)
+          total_label_strings.append(label_strings)
+          total_confidences.append(confidences)
          windowed_timestamp += 1
-        context_feature_dict = {
-            'image/height':
-                dataset_util.int64_feature(int(height)),
-            'image/width':
-                dataset_util.int64_feature(int(width)),
-            'image/format':
-                dataset_util.bytes_feature('jpeg'.encode('utf8')),
-        }
-        sequence_feature_dict = {
-            'image/source_id':
-                feature_list_feature(total_source_ids),
-            'image/encoded':
-                feature_list_feature(total_images),
-            'region/bbox/xmin':
-                feature_list_feature(total_xmins),
-            'region/bbox/xmax':
-                feature_list_feature(total_xmaxs),
-            'region/bbox/ymin':
-                feature_list_feature(total_ymins),
-            'region/bbox/ymax':
-                feature_list_feature(total_ymaxs),
-            'region/label/index':
-                feature_list_feature(total_labels),
-            'region/label/string':
-                feature_list_feature(total_label_strings),
-            'region/label/confidence':
-                feature_list_feature(total_confidences),  # all ones
-            'region/is_annotated':
-                feature_list_feature(total_is_annotated)  # all ones
-        }
        if len(total_xmins) > 0:
-          yield tf.train.SequenceExample(
-              context=tf.train.Features(feature=context_feature_dict),
-              feature_lists=tf.train.FeatureLists(
-                  feature_list=sequence_feature_dict))
+          yield seq_example_util.make_sequence_example(
+              "AVA", media_id, total_images, int(height), int(width),
+              'jpeg', total_source_ids, None, total_is_annotated,
+              [list(z) for z in zip(ymins, xmins, ymaxs, xmaxs)],
+              total_label_strings)
        # Move middle_frame_time, skipping excluded frames
        frames_mv = 0
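
The helper call packs the same context features (height, width, format) and feature lists that the removed dicts assembled by hand. A minimal usage sketch with hypothetical values; the import path and the reading of the positional arguments (the None appears to be a per-frame timestamps slot) are inferred from the call above rather than confirmed by this diff:

    from object_detection.dataset_tools import seq_example_util

    # Two hypothetical frames; the real code passes JPEG bytes from cv2.imencode.
    frames = [b'<jpeg bytes 0>', b'<jpeg bytes 1>']
    example = seq_example_util.make_sequence_example(
        'AVA', 'hypothetical_media_id', frames, 360, 640, 'jpeg',
        ['902_hypothetical_media_id', '903_hypothetical_media_id'],  # source ids
        None,                      # timestamps slot, unused by this caller
        [1, 1],                    # every frame is annotated
        [[[0.1, 0.2, 0.8, 0.9]],   # frame 0: one [ymin, xmin, ymax, xmax] box
         [[0.0, 0.1, 0.5, 0.6]]],  # frame 1: one box
        [['stand'], ['sit']])      # per-frame label strings
    # Returns a tf.train.SequenceExample ready to serialize into a TFRecord.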
...
@@ -138,7 +138,7 @@ def boxes_to_box_components(bboxes):
  xmax_list = []
  for bbox in bboxes:
    bbox = np.array(bbox).astype(np.float32)
-    ymin, xmin, ymax, xmax = np.split(bbox, 4, axis=1)
+    ymin, xmin, ymax, xmax = np.split(bbox, 4, axis=0)
    ymin_list.append(np.reshape(ymin, [-1]))
    xmin_list.append(np.reshape(xmin, [-1]))
    ymax_list.append(np.reshape(ymax, [-1]))
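
The axis change appears to match the new caller above, which hands each element over as a flat [ymin, xmin, ymax, xmax] array rather than an [N, 4] matrix; splitting such a 1-D array on axis=1 would raise an axis error. A quick standalone check:

    import numpy as np

    bbox = np.array([0.1, 0.2, 0.8, 0.9], dtype=np.float32)  # shape (4,)
    ymin, xmin, ymax, xmax = np.split(bbox, 4, axis=0)       # four (1,) arrays
    print(np.reshape(ymin, [-1]))                            # -> [0.1]
    # np.split(bbox, 4, axis=1) would fail here: a 1-D array has no axis 1.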
...
@@ -133,7 +133,11 @@ DETECTION_MODULE_MAP = {
 def export_inference_graph(input_type,
                            pipeline_config,
                            trained_checkpoint_dir,
-                           output_directory):
+                           output_directory,
+                           use_side_inputs,
+                           side_input_shapes,
+                           side_input_types,
+                           side_input_names):
   """Exports inference graph for the model specified in the pipeline config.

   This function creates `output_directory` if it does not already exist,
...
@@ -106,6 +106,27 @@ flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
 flags.DEFINE_string('config_override', '',
                     'pipeline_pb2.TrainEvalPipelineConfig '
                     'text proto to override pipeline_config_path.')
+flags.DEFINE_boolean('use_side_inputs', False,
+                     'If True, uses side inputs as well as image inputs.')
+flags.DEFINE_string('side_input_shapes', None,
+                    'If use_side_inputs is True, this explicitly sets '
+                    'the shape of the side input tensors to a fixed size. The '
+                    'dimensions are to be provided as a comma-separated list '
+                    'of integers. A value of -1 can be used for unknown '
+                    'dimensions. A `/` denotes a break, starting the shape of '
+                    'the next side input tensor. This flag is required if '
+                    'using side inputs.')
+flags.DEFINE_string('side_input_types', None,
+                    'If use_side_inputs is True, this explicitly sets '
+                    'the type of the side input tensors. The types are to be '
+                    'provided as a comma-separated list, each of `string`, '
+                    '`integer`, or `float`. '
+                    'This flag is required if using side inputs.')
+flags.DEFINE_string('side_input_names', None,
+                    'If use_side_inputs is True, this explicitly sets '
+                    'the names of the side input tensors required by the '
+                    'model. The names are to be provided as a comma-separated '
+                    'list of strings. This flag is required if using side '
+                    'inputs.')
 flags.mark_flag_as_required('pipeline_config_path')
 flags.mark_flag_as_required('trained_checkpoint_dir')
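
To make the `/`- and comma-separated shape syntax concrete, here is an illustrative parser; this is a sketch of the documented format, not the library's actual implementation:

    def parse_side_input_shapes(flag_value):
      """E.g. '1,512/1,-1' -> [[1, 512], [1, -1]]; -1 marks an unknown dim."""
      return [[int(dim) for dim in shape.split(',')]
              for shape in flag_value.split('/')]

    assert parse_side_input_shapes('1,512/1,-1') == [[1, 512], [1, -1]]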
@@ -119,7 +140,8 @@ def main(_):
   text_format.Merge(FLAGS.config_override, pipeline_config)
   exporter_lib_v2.export_inference_graph(
       FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_dir,
-      FLAGS.output_directory)
+      FLAGS.output_directory, FLAGS.use_side_inputs, FLAGS.side_input_shapes,
+      FLAGS.side_input_types, FLAGS.side_input_names)


 if __name__ == '__main__':
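
Putting the new flags together, an invocation could look like the following; the paths and the two side inputs are hypothetical, and the script name assumes this file is the v2 exporter binary:

    python exporter_main_v2.py \
      --input_type=image_tensor \
      --pipeline_config_path=path/to/pipeline.config \
      --trained_checkpoint_dir=path/to/checkpoint_dir \
      --output_directory=path/to/exported_model \
      --use_side_inputs=True \
      --side_input_shapes=1,512/1 \
      --side_input_types=float,integer \
      --side_input_names=image_embedding,timestamp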
...
"""Setup script for object_detection.""" """Setup script for object_detection with TF2.0."""
import os
from setuptools import find_packages from setuptools import find_packages
from setuptools import setup from setuptools import setup
# Note: adding apache-beam to required packages causes conflict with
REQUIRED_PACKAGES = ['Pillow>=1.0', 'Matplotlib>=2.1', 'Cython>=0.28.1'] # tf-models-offical requirements. These packages request for incompatible
# oauth2client package.
REQUIRED_PACKAGES = ['pillow', 'lxml', 'matplotlib', 'Cython', 'contextlib2',
'tf-slim', 'six', 'pycocotools', 'scipy', 'pandas',
'tf-models-official']
setup( setup(
name='object_detection', name='object_detection',
version='0.1', version='0.1',
install_requires=REQUIRED_PACKAGES, install_requires=REQUIRED_PACKAGES,
include_package_data=True, include_package_data=True,
packages=[p for p in find_packages() if p.startswith('object_detection')], packages=(
[p for p in find_packages() if p.startswith('object_detection')] +
find_packages(where=os.path.join('.', 'slim'))),
package_dir={
'datasets': os.path.join('slim', 'datasets'),
'nets': os.path.join('slim', 'nets'),
'preprocessing': os.path.join('slim', 'preprocessing'),
'deployment': os.path.join('slim', 'deployment'),
'scripts': os.path.join('slim', 'scripts'),
},
description='Tensorflow Object Detection Library', description='Tensorflow Object Detection Library',
python_requires='>3.6',
) )
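
Since `packages` and `package_dir` reference `slim` by relative path, the script presumably has to be run from a directory containing both object_detection/ and slim/ (e.g. the models research/ directory):

    python -m pip install .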