Commit b152ed9c authored by Kaushik Shivakumar

fix

parent 2bd53cf1
......@@ -60,6 +60,7 @@ import sys
import zipfile
import collections
import glob
import hashlib
from absl import app
from absl import flags
......@@ -69,8 +70,10 @@ from six.moves import urllib
import tensorflow.compat.v1 as tf
import cv2
from object_detection.utils import dataset_util
from object_detection.dataset_tools import seq_example_util
from object_detection.protos import string_int_label_map_pb2
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
POSSIBLE_TIMESTAMPS = range(902, 1798)
ANNOTATION_URL = "https://research.google.com/ava/download/ava_v2.2.zip"
......@@ -116,7 +119,8 @@ class Ava(object):
splits_to_process="train,val,test",
video_path_format_string=None,
seconds_per_sequence=10,
hop_between_sequences=10):
hop_between_sequences=10,
examples_for_context=False):
"""Downloads data and generates sharded TFRecords.
Downloads the data files, generates metadata, and processes the metadata
......@@ -133,11 +137,15 @@ class Ava(object):
hop_between_sequences: The gap between the centers of
successive sequences.
examples_for_context: Whether to generate tf.Example objects instead
of tf.SequenceExample objects. If true, generates tf.Examples
for use as context in Context R-CNN.
"""
example_function = self._generate_sequence_examples
if examples_for_context:
example_function = self._generate_examples
logging.info("Downloading data.")
download_output = self._download_data()
for key in splits_to_process.split(","):
logging.info("Generating examples for split: %s", key)
all_metadata = list(self._generate_examples(
all_metadata = list(example_function(
download_output[0][key][0], download_output[0][key][1],
download_output[1], seconds_per_sequence, hop_between_sequences,
video_path_format_string))
......@@ -155,7 +163,7 @@ class Ava(object):
writers[i % len(writers)].write(seq_ex.SerializeToString())
logging.info("Data extraction complete.")
def _generate_examples(self, annotation_file, excluded_file, label_map,
def _generate_sequence_examples(self, annotation_file, excluded_file, label_map,
seconds_per_sequence, hop_between_sequences,
video_path_format_string):
"""For each row in the annotation CSV, generates the corresponding examples.
......@@ -275,6 +283,154 @@ class Ava(object):
cur_vid.release()
def _generate_examples(self, annotation_file, excluded_file, label_map,
seconds_per_sequence, hop_between_sequences,
video_path_format_string):
"""For each row in the annotation CSV, generates the corresponding
examples. When iterating through frames for a single example, skips
over excluded frames. Generates equal-length sequence examples, each with
length seconds_per_sequence (1 fps) and gaps of hop_between_sequences
frames (and seconds) between them, possible greater due to excluded frames.
Args:
annotation_file: path to the file of AVA CSV annotations.
excluded_path: path to a CSV file of excluded timestamps for each video.
label_map: an {int: string} label map.
seconds_per_sequence: The number of seconds per example in each example.
hop_between_sequences: The hop between sequences. If less than
seconds_per_sequence, will overlap.
Yields:
Each prepared tf.Example of metadata also containing video frames
"""
fieldnames = ["id", "timestamp_seconds", "xmin", "ymin", "xmax", "ymax",
"action_label"]
frame_excluded = {}
# Create a sparse map keyed by (video id, timestamp) of excluded frames.
with open(excluded_file, "r") as excluded:
reader = csv.reader(excluded)
for row in reader:
frame_excluded[(row[0], int(float(row[1])))] = True
with open(annotation_file, "r") as annotations:
reader = csv.DictReader(annotations, fieldnames)
frame_annotations = collections.defaultdict(list)
ids = set()
# Aggregate annotations by video and timestamp:
for row in reader:
ids.add(row["id"])
key = (row["id"], int(float(row["timestamp_seconds"])))
frame_annotations[key].append(row)
# For each video, generate an example for each sampled frame:
logging.info("Generating metadata...")
media_num = 1
for media_id in ids:
logging.info("%d/%d, ignore warnings.\n" % (media_num, len(ids)))
media_num += 1
filepath = glob.glob(
video_path_format_string.format(media_id) + "*")[0]
filename = filepath.split("/")[-1]
cur_vid = cv2.VideoCapture(filepath)
width = cur_vid.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cur_vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
middle_frame_time = POSSIBLE_TIMESTAMPS[0]
total_non_excluded = 0
while middle_frame_time < POSSIBLE_TIMESTAMPS[-1]:
if (media_id, middle_frame_time) not in frame_excluded:
total_non_excluded += 1
middle_frame_time += 1
middle_frame_time = POSSIBLE_TIMESTAMPS[0]
cur_frame_num = 0
while middle_frame_time < POSSIBLE_TIMESTAMPS[-1]:
# Skip excluded timestamps before decoding the frame, to avoid wasted work.
if (media_id, middle_frame_time) in frame_excluded:
middle_frame_time += 1
logging.info("Ignoring and skipping excluded frame.")
continue
cur_vid.set(cv2.CAP_PROP_POS_MSEC,
middle_frame_time * SECONDS_TO_MILLI)
success, image = cur_vid.read()
if not success:
logging.warning("Failed to read frame at time %s.", middle_frame_time)
middle_frame_time += 1.0 / 3
continue
success, buffer = cv2.imencode(".jpg", image)
bufstring = buffer.tobytes()
cur_frame_num += 1
source_id = str(middle_frame_time) + "_" + media_id
xmins = []
xmaxs = []
ymins = []
ymaxs = []
areas = []
labels = []
label_strings = []
confidences = []
for row in frame_annotations[(media_id, middle_frame_time)]:
if len(row) > 2 and int(row["action_label"]) in label_map:
xmins.append(float(row["xmin"]))
xmaxs.append(float(row["xmax"]))
ymins.append(float(row["ymin"]))
ymaxs.append(float(row["ymax"]))
areas.append(float((xmaxs[-1] - xmins[-1]) *
(ymaxs[-1] - ymins[-1])) / 2)
labels.append(int(row["action_label"]))
label_strings.append(label_map[int(row["action_label"])])
confidences.append(1)
else:
logging.warning("Unknown label: %s", row["action_label"])
# Step 1/3 second (1.0 / 3 keeps the step a float under Python 2), then
# snap to the nearest integer timestamp when within float error.
middle_frame_time += 1.0 / 3
if abs(middle_frame_time - round(middle_frame_time)) < 0.0001:
middle_frame_time = round(middle_frame_time)
key = hashlib.sha256(bufstring).hexdigest()
# Synthesize a capture time from the frame's offset into the video.
date_captured_feature = ("2020-06-17 00:%02d:%02d" % (
(middle_frame_time - 900) * 3 // 60, (middle_frame_time - 900) * 3 % 60))
context_feature_dict = {
'image/height':
dataset_util.int64_feature(int(height)),
'image/width':
dataset_util.int64_feature(int(width)),
'image/format':
dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/source_id':
dataset_util.bytes_feature(source_id.encode("utf8")),
'image/filename':
dataset_util.bytes_feature(source_id.encode("utf8")),
'image/encoded':
dataset_util.bytes_feature(bufstring),
'image/key/sha256':
dataset_util.bytes_feature(key.encode('utf8')),
'image/object/bbox/xmin':
dataset_util.float_list_feature(xmins),
'image/object/bbox/xmax':
dataset_util.float_list_feature(xmaxs),
'image/object/bbox/ymin':
dataset_util.float_list_feature(ymins),
'image/object/bbox/ymax':
dataset_util.float_list_feature(ymaxs),
'image/object/area':
dataset_util.float_list_feature(areas),
'image/object/class/label':
dataset_util.int64_list_feature(labels),
'image/object/class/text':
dataset_util.bytes_list_feature(label_strings),
'image/location':
dataset_util.bytes_feature(media_id.encode('utf8')),
'image/date_captured':
dataset_util.bytes_feature(date_captured_feature.encode('utf8')),
'image/seq_num_frames':
dataset_util.int64_feature(total_non_excluded),
'image/seq_frame_num':
dataset_util.int64_feature(cur_frame_num),
'image/seq_id':
dataset_util.bytes_feature(media_id.encode('utf8')),
}
yield tf.train.Example(
features=tf.train.Features(feature=context_feature_dict))
cur_vid.release()
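# A minimal sketch (an assumption, not part of this script) of reading the
# yielded tf.Examples back from a written shard; the record path is
# illustrative, and the feature keys match context_feature_dict above:
#   for record in tf.python_io.tf_record_iterator("/tmp/ava/train-0.tfrecord"):
#     example = tf.train.Example.FromString(record)
#     media_id = example.features.feature["image/seq_id"].bytes_list.value[0]
#     frame = example.features.feature["image/encoded"].bytes_list.value[0]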
def _download_data(self):
"""Downloads and extracts data if not already available."""
if sys.version_info >= (3, 0):
......@@ -300,14 +456,27 @@ class Ava(object):
SPLITS[split]["excluded-csv"] = excluded_csv_path
paths[split] = (csv_path, excluded_csv_path)
label_map = self.get_label_map(os.path.join(self.path_to_data_download,
"ava_action_list_v2.2.pbtxt"))
# Use the COCO label map in place of the AVA action label map:
# label_map = self.get_label_map(os.path.join(self.path_to_data_download,
#     "ava_action_list_v2.2.pbtxt"))
label_map = self.get_label_map("object_detection/data/mscoco_label_map.pbtxt")
return paths, label_map
def get_label_map(self, path):
"""Parsess a label map into {integer:string} format."""
"""Parses a label map into {integer:string} format."""
label_map = label_map_util.load_labelmap(path)
label_map_dict = {}
for item in label_map.item:
label_map_dict[item.id] = item.name.encode("utf8")
......@@ -322,16 +491,12 @@ class Ava(object):
if "id:" in line:
current_id = int(line.split()[1])
if "}" in line:
label_map[current_id] = bytes23(current_label)
logging.info(label_map)
label_map[current_id] = bytes(current_label, "utf8")"""
print('label map dict')
logging.info(label_map_dict)
assert len(label_map) == NUM_CLASSES
return label_map
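# For illustration, with mscoco_label_map.pbtxt the returned dict maps COCO
# category ids to MID name strings, e.g. (values shown as an assumption):
#   {1: b"/m/01g317", 2: b"/m/0199g", ...}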
def bytes23(string):
"""Creates a bytes string in either Python 2 or 3."""
if sys.version_info >= (3, 0):
return bytes(string, "utf8")
return bytes(string)
@contextlib.contextmanager
def _close_on_exit(writers):
......@@ -350,7 +515,8 @@ def main(argv):
flags.FLAGS.splits_to_process,
flags.FLAGS.video_path_format_string,
flags.FLAGS.seconds_per_sequence,
flags.FLAGS.hop_between_sequences)
flags.FLAGS.hop_between_sequences,
flags.FLAGS.examples_for_context)
if __name__ == "__main__":
flags.DEFINE_string("path_to_download_data",
......@@ -375,4 +541,8 @@ if __name__ == "__main__":
10,
"The hop between sequences. If less than "
"seconds_per_sequence, will overlap.")
flags.DEFINE_boolean("examples_for_context",
False,
"Whether to generate examples instead of sequence examples. "
"If true, will generate tf.Example objects for use in Context R-CNN.")
app.run(main)
......@@ -152,7 +152,7 @@ def load_labelmap(path):
Returns:
a StringIntLabelMapProto
"""
with tf.io.gfile.GFile(path, 'r') as fid:
with tf.io.gfile.GFile(path, 'rb') as fid:
label_map_string = fid.read()
label_map = string_int_label_map_pb2.StringIntLabelMap()
try:
......