Commit b152ed9c authored by Kaushik Shivakumar

fix

parent 2bd53cf1
@@ -60,6 +60,7 @@ import sys
 import zipfile
 import collections
 import glob
+import hashlib
 from absl import app
 from absl import flags
@@ -69,8 +70,10 @@ from six.moves import urllib
 import tensorflow.compat.v1 as tf
 import cv2
-from object_detection.utils import dataset_util
 from object_detection.dataset_tools import seq_example_util
+from object_detection.protos import string_int_label_map_pb2
+from object_detection.utils import dataset_util
+from object_detection.utils import label_map_util
 
 POSSIBLE_TIMESTAMPS = range(902, 1798)
 ANNOTATION_URL = "https://research.google.com/ava/download/ava_v2.2.zip"
@@ -116,7 +119,8 @@ class Ava(object):
                               splits_to_process="train,val,test",
                               video_path_format_string=None,
                               seconds_per_sequence=10,
-                              hop_between_sequences=10):
+                              hop_between_sequences=10,
+                              examples_for_context=False):
     """Downloads data and generates sharded TFRecords.
 
     Downloads the data files, generates metadata, and processes the metadata
@@ -133,11 +137,15 @@ class Ava(object):
       hop_between_sequences: The gap between the centers of
         successive sequences.
     """
+    example_function = self._generate_sequence_examples
+    if examples_for_context:
+      example_function = self._generate_examples
     logging.info("Downloading data.")
     download_output = self._download_data()
     for key in splits_to_process.split(","):
       logging.info("Generating examples for split: %s", key)
-      all_metadata = list(self._generate_examples(
+      all_metadata = list(example_function(
           download_output[0][key][0], download_output[0][key][1],
           download_output[1], seconds_per_sequence, hop_between_sequences,
           video_path_format_string))
@@ -155,7 +163,7 @@ class Ava(object):
           writers[i % len(writers)].write(seq_ex.SerializeToString())
     logging.info("Data extraction complete.")
 
-  def _generate_examples(self, annotation_file, excluded_file, label_map,
+  def _generate_sequence_examples(self, annotation_file, excluded_file, label_map,
                          seconds_per_sequence, hop_between_sequences,
                          video_path_format_string):
     """For each row in the annotation CSV, generates the corresponding examples.
@@ -275,6 +283,154 @@ class Ava(object):
     cur_vid.release()
+  def _generate_examples(self, annotation_file, excluded_file, label_map,
+                         seconds_per_sequence, hop_between_sequences,
+                         video_path_format_string):
+    """For each row in the annotation CSV, generates corresponding examples.
+
+    When iterating through frames for a single example, skips over excluded
+    frames. Generates equal-length sequence examples, each with length
+    seconds_per_sequence (1 fps) and gaps of hop_between_sequences frames
+    (and seconds) between them, possibly greater due to excluded frames.
+
+    Args:
+      annotation_file: path to the file of AVA CSV annotations.
+      excluded_file: path to a CSV file of excluded timestamps for each video.
+      label_map: an {int: string} label map.
+      seconds_per_sequence: The number of seconds per example in each example.
+      hop_between_sequences: The hop between sequences. If less than
+        seconds_per_sequence, will overlap.
+      video_path_format_string: File path format string for the videos.
+
+    Yields:
+      Each prepared tf.Example of metadata, also containing the encoded frame.
+    """
+    fieldnames = ["id", "timestamp_seconds", "xmin", "ymin", "xmax", "ymax",
+                  "action_label"]
+    frame_excluded = {}
+    # Create a sparse, nested map of videos and frame indices.
+    with open(excluded_file, "r") as excluded:
+      reader = csv.reader(excluded)
+      for row in reader:
+        frame_excluded[(row[0], int(float(row[1])))] = True
+    with open(annotation_file, "r") as annotations:
+      reader = csv.DictReader(annotations, fieldnames)
+      frame_annotations = collections.defaultdict(list)
+      ids = set()
+      # Aggregate annotations by video and timestamp:
+      for row in reader:
+        ids.add(row["id"])
+        key = (row["id"], int(float(row["timestamp_seconds"])))
+        frame_annotations[key].append(row)
+      # For each video, find the aggregates near each sampled frame:
+      logging.info("Generating metadata...")
+      media_num = 1
+      for media_id in ids:
+        logging.info("%d/%d, ignore warnings.\n", media_num, len(ids))
+        media_num += 1
+        filepath = glob.glob(
+            video_path_format_string.format(media_id) + "*")[0]
+        filename = filepath.split("/")[-1]
+        cur_vid = cv2.VideoCapture(filepath)
+        width = cur_vid.get(cv2.CAP_PROP_FRAME_WIDTH)
+        height = cur_vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
+        # Count non-excluded frames to fill image/seq_num_frames.
+        middle_frame_time = POSSIBLE_TIMESTAMPS[0]
+        total_non_excluded = 0
+        while middle_frame_time < POSSIBLE_TIMESTAMPS[-1]:
+          if (media_id, middle_frame_time) not in frame_excluded:
+            total_non_excluded += 1
+          middle_frame_time += 1
+        middle_frame_time = POSSIBLE_TIMESTAMPS[0]
+        cur_frame_num = 0
+        while middle_frame_time < POSSIBLE_TIMESTAMPS[-1]:
+          cur_vid.set(cv2.CAP_PROP_POS_MSEC,
+                      middle_frame_time * SECONDS_TO_MILLI)
+          success, image = cur_vid.read()
+          success, buffer = cv2.imencode('.jpg', image)
+          bufstring = buffer.tostring()
+          if (media_id, middle_frame_time) in frame_excluded:
+            middle_frame_time += 1
+            logging.info("Ignoring and skipping excluded frame.")
+            continue
+          cur_frame_num += 1
+          source_id = str(middle_frame_time) + "_" + media_id
+          xmins = []
+          xmaxs = []
+          ymins = []
+          ymaxs = []
+          areas = []
+          labels = []
+          label_strings = []
+          confidences = []
+          for row in frame_annotations[(media_id, middle_frame_time)]:
+            if len(row) > 2 and int(row["action_label"]) in label_map:
+              xmins.append(float(row["xmin"]))
+              xmaxs.append(float(row["xmax"]))
+              ymins.append(float(row["ymin"]))
+              ymaxs.append(float(row["ymax"]))
+              areas.append(float((xmaxs[-1] - xmins[-1]) *
+                                 (ymaxs[-1] - ymins[-1])) / 2)
+              labels.append(int(row["action_label"]))
+              label_strings.append(label_map[int(row["action_label"])])
+              confidences.append(1)
+            else:
+              logging.warning("Unknown label: %s", row["action_label"])
+          middle_frame_time += 1/3
+          # Snap accumulated floating-point error back to whole seconds.
+          if abs(middle_frame_time - round(middle_frame_time)) < 0.0001:
+            middle_frame_time = round(middle_frame_time)
+          key = hashlib.sha256(bufstring).hexdigest()
+          date_captured_feature = (
+              "2020-06-17 00:%02d:%02d" % ((middle_frame_time - 900) * 3 // 60,
+                                           (middle_frame_time - 900) * 3 % 60))
+          context_feature_dict = {
+              'image/height':
+                  dataset_util.int64_feature(int(height)),
+              'image/width':
+                  dataset_util.int64_feature(int(width)),
+              'image/format':
+                  dataset_util.bytes_feature('jpeg'.encode('utf8')),
+              'image/source_id':
+                  dataset_util.bytes_feature(source_id.encode("utf8")),
+              'image/filename':
+                  dataset_util.bytes_feature(source_id.encode("utf8")),
+              'image/encoded':
+                  dataset_util.bytes_feature(bufstring),
+              'image/key/sha256':
+                  dataset_util.bytes_feature(key.encode('utf8')),
+              'image/object/bbox/xmin':
+                  dataset_util.float_list_feature(xmins),
+              'image/object/bbox/xmax':
+                  dataset_util.float_list_feature(xmaxs),
+              'image/object/bbox/ymin':
+                  dataset_util.float_list_feature(ymins),
+              'image/object/bbox/ymax':
+                  dataset_util.float_list_feature(ymaxs),
+              'image/object/area':
+                  dataset_util.float_list_feature(areas),
+              'image/object/class/label':
+                  dataset_util.int64_list_feature(labels),
+              'image/object/class/text':
+                  dataset_util.bytes_list_feature(label_strings),
+              'image/location':
+                  dataset_util.bytes_feature(media_id.encode('utf8')),
+              'image/date_captured':
+                  dataset_util.bytes_feature(
+                      date_captured_feature.encode('utf8')),
+              'image/seq_num_frames':
+                  dataset_util.int64_feature(total_non_excluded),
+              'image/seq_frame_num':
+                  dataset_util.int64_feature(cur_frame_num),
+              'image/seq_id':
+                  dataset_util.bytes_feature(media_id.encode('utf8')),
+          }
+          yield tf.train.Example(
+              features=tf.train.Features(feature=context_feature_dict))
+        cur_vid.release()
   def _download_data(self):
     """Downloads and extracts data if not already available."""
     if sys.version_info >= (3, 0):
@@ -300,14 +456,27 @@ class Ava(object):
       SPLITS[split]["excluded-csv"] = excluded_csv_path
       paths[split] = (csv_path, excluded_csv_path)
-    label_map = self.get_label_map(os.path.join(self.path_to_data_download,
-                                                "ava_action_list_v2.2.pbtxt"))
+    #label_map = self.get_label_map(os.path.join(self.path_to_data_download, "ava_action_list_v2.2.pbtxt"))
+    #
+    label_map = self.get_label_map("object_detection/data/mscoco_label_map.pbtxt")
     return paths, label_map
 
   def get_label_map(self, path):
-    """Parsess a label map into {integer:string} format."""
+    """Parses a label map into {integer:string} format."""
     label_map = {}
-    with open(path, "r") as f:
+    label_map = label_map_util.load_labelmap(path)
+    print(label_map)
+    label_map_dict = {}
+    for item in label_map.item:
+      label_map_dict[item.name] = item.label_id
+    with open(path, "rb") as f:
+      #label_map_util.load_labelmap()
+      #label_map_str = f.read()
+      #print(str(label_map_str))
+      #label_map = string_int_label_map_pb2.StringIntLabelMap()
+      #label_map.ParseFromString(label_map_str)
+      pass
+    """
       current_id = -1
       current_label = ""
       for line in f:
@@ -322,16 +491,12 @@ class Ava(object):
         if "id:" in line:
           current_id = int(line.split()[1])
         if "}" in line:
-          label_map[current_id] = bytes23(current_label)
-    logging.info(label_map)
+          label_map[current_id] = bytes(current_label, "utf8")"""
+    print('label map dict')
+    logging.info(label_map_dict)
     assert len(label_map) == NUM_CLASSES
     return label_map
 
-def bytes23(string):
-  """Creates a bytes string in either Python 2 or 3."""
-  if sys.version_info >= (3, 0):
-    return bytes(string, "utf8")
-  return bytes(string)
-
 @contextlib.contextmanager
 def _close_on_exit(writers):
@@ -350,7 +515,8 @@ def main(argv):
         flags.FLAGS.splits_to_process,
         flags.FLAGS.video_path_format_string,
         flags.FLAGS.seconds_per_sequence,
-        flags.FLAGS.hop_between_sequences)
+        flags.FLAGS.hop_between_sequences,
+        flags.FLAGS.examples_for_context)
 
 if __name__ == "__main__":
   flags.DEFINE_string("path_to_download_data",
@@ -375,4 +541,8 @@ if __name__ == "__main__":
                       10,
                       "The hop between sequences. If less than "
                       "seconds_per_sequence, will overlap.")
+  flags.DEFINE_boolean("examples_for_context",
+                       False,
+                       "Whether to generate examples instead of sequence "
+                       "examples. If true, will generate tf.Example objects "
+                       "for use in Context R-CNN.")
   app.run(main)
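For context, the per-frame records written when examples_for_context is enabled can be read back with a feature spec that mirrors the context_feature_dict keys in the diff above. The sketch below is not part of the commit: the spec lists only a subset of the keys, and the shard filename is a placeholder.

import tensorflow.compat.v1 as tf

# Spec mirroring a subset of the context_feature_dict keys written above;
# extend it with the remaining keys as needed.
CONTEXT_FEATURES = {
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
    "image/source_id": tf.io.FixedLenFeature([], tf.string),
    "image/height": tf.io.FixedLenFeature([], tf.int64),
    "image/width": tf.io.FixedLenFeature([], tf.int64),
    "image/object/bbox/xmin": tf.io.VarLenFeature(tf.float32),
    "image/object/bbox/xmax": tf.io.VarLenFeature(tf.float32),
    "image/object/bbox/ymin": tf.io.VarLenFeature(tf.float32),
    "image/object/bbox/ymax": tf.io.VarLenFeature(tf.float32),
    "image/object/class/label": tf.io.VarLenFeature(tf.int64),
    "image/seq_id": tf.io.FixedLenFeature([], tf.string),
    "image/seq_frame_num": tf.io.FixedLenFeature([], tf.int64),
    "image/seq_num_frames": tf.io.FixedLenFeature([], tf.int64),
}

def parse_frame_example(serialized):
  # Decodes one serialized tf.Example produced by _generate_examples.
  return tf.io.parse_single_example(serialized, CONTEXT_FEATURES)

# Hypothetical usage; the shard name is illustrative only.
dataset = tf.data.TFRecordDataset("ava_train.tfrecord-00000-of-00100")
dataset = dataset.map(parse_frame_example)

Passing --examples_for_context=True on the command line selects the new _generate_examples path; by default the script still emits tf.SequenceExamples.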
@@ -152,7 +152,7 @@ def load_labelmap(path):
 
   Returns:
     a StringIntLabelMapProto
   """
-  with tf.io.gfile.GFile(path, 'r') as fid:
+  with tf.io.gfile.GFile(path, 'rb') as fid:
     label_map_string = fid.read()
     label_map = string_int_label_map_pb2.StringIntLabelMap()
     try:
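The one-character change from 'r' to 'rb' matters under Python 3: if the label map file is a serialized binary proto rather than text format, ParseFromString requires bytes, which a text-mode read does not provide. Below is a minimal sketch of that fallback behavior, an illustration under assumption and not the library code verbatim.

from google.protobuf import text_format
from object_detection.protos import string_int_label_map_pb2

def parse_label_map_bytes(data):
  # Try the human-readable .pbtxt form first, then fall back to binary proto.
  label_map = string_int_label_map_pb2.StringIntLabelMap()
  try:
    text_format.Merge(data.decode("utf-8"), label_map)
  except (UnicodeDecodeError, text_format.ParseError):
    # A serialized binary proto must be handed over as raw bytes,
    # which is why the file is now opened in 'rb' mode.
    label_map.ParseFromString(data)
  return label_map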