Commit 98d9f3b8 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

lint more, clean up, push

parent ff91b5c4
......@@ -50,7 +50,7 @@ Note that the number of videos changes in the data set over time, so it will
likely be necessary to change the expected number of examples.
The argument video_path_format_string expects a value as such:
"/path/to/videos/{0}"
'/path/to/videos/{0}'
"""
......@@ -78,28 +78,28 @@ from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
# Valid center-frame timestamps (in seconds) within each AVA clip; annotations
# cover seconds 902 through 1797 of every video.
POSSIBLE_TIMESTAMPS = range(902, 1798)
# Archive containing the AVA v2.2 annotation CSVs and label map.
ANNOTATION_URL = 'https://research.google.com/ava/download/ava_v2.2.zip'
# Multiplier to convert seconds to milliseconds (for cv2 CAP_PROP_POS_MSEC).
SECONDS_TO_MILLI = 1000
# Output shard filename pattern; %s is the split name ('train'/'val'/'test').
FILEPATTERN = 'ava_actions_%s_1fps_rgb'
# Per-split shard counts and expected example counts; the 'csv' and
# 'excluded-csv' paths are filled in after the annotation archive is
# downloaded and extracted.
SPLITS = {
    'train': {
        'shards': 1000,
        'examples': 862663,
        'csv': '',
        'excluded-csv': ''
    },
    'val': {
        'shards': 100,
        'examples': 243029,
        'csv': '',
        'excluded-csv': ''
    },
    # Test doesn't have ground truth, so TF Records can't be created.
    'test': {
        'shards': 100,
        'examples': 0,
        'csv': '',
        'excluded-csv': ''
    }
}
......@@ -115,12 +115,12 @@ class Ava(object):
def __init__(self, path_to_output_dir, path_to_data_download):
  """Stores output and download directories for later record generation.

  Args:
    path_to_output_dir: Directory where TFRecord shards will be written.
      Required; must be non-empty.
    path_to_data_download: Directory into which annotation data is downloaded.

  Raises:
    ValueError: If path_to_output_dir is empty or None.
  """
  if not path_to_output_dir:
    raise ValueError('You must supply the path to the data directory.')
  self.path_to_data_download = path_to_data_download
  self.path_to_output_dir = path_to_output_dir
def generate_and_write_records(self,
splits_to_process="train,val,test",
splits_to_process='train,val,test',
video_path_format_string=None,
seconds_per_sequence=10,
hop_between_sequences=10,
......@@ -133,9 +133,9 @@ class Ava(object):
original data files can be deleted.
Args:
splits_to_process: csv string of which splits to process. Allows providing
a custom CSV with the CSV flag. The original data is still downloaded
to generate the label_map.
splits_to_process: csv string of which splits to process. Allows
providing a custom CSV with the CSV flag. The original data is still
downloaded to generate the label_map.
video_path_format_string: The format string for the path to local files.
seconds_per_sequence: The length of each sequence, in seconds.
hop_between_sequences: The gap between the centers of
......@@ -145,32 +145,33 @@ class Ava(object):
if examples_for_context:
example_function = self._generate_examples
logging.info("Downloading data.")
logging.info('Downloading data.')
download_output = self._download_data()
for key in splits_to_process.split(","):
logging.info("Generating examples for split: %s", key)
for key in splits_to_process.split(','):
logging.info('Generating examples for split: %s', key)
all_metadata = list(example_function(
download_output[0][key][0], download_output[0][key][1],
download_output[1], seconds_per_sequence, hop_between_sequences,
video_path_format_string))
logging.info("An example of the metadata: ")
logging.info('An example of the metadata: ')
logging.info(all_metadata[0])
random.seed(47)
random.shuffle(all_metadata)
shards = SPLITS[key]["shards"]
shards = SPLITS[key]['shards']
shard_names = [os.path.join(
self.path_to_output_dir, FILEPATTERN % key + "-%05d-of-%05d" % (
self.path_to_output_dir, FILEPATTERN % key + '-%05d-of-%05d' % (
i, shards)) for i in range(shards)]
writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
writers = [tf.io.TFRecordWriter(shard) for shard in shard_names]
with _close_on_exit(writers) as writers:
for i, seq_ex in enumerate(all_metadata):
writers[i % len(writers)].write(seq_ex.SerializeToString())
logging.info("Data extraction complete.")
logging.info('Data extraction complete.')
def _generate_sequence_examples(self, annotation_file, excluded_file, label_map,
seconds_per_sequence, hop_between_sequences,
def _generate_sequence_examples(self, annotation_file, excluded_file,
label_map, seconds_per_sequence,
hop_between_sequences,
video_path_format_string):
"""For each row in the annotation CSV, generates the corresponding examples.
"""For each row in the annotation CSV, generates corresponding examples.
When iterating through frames for a single sequence example, skips over
excluded frames. When moving to the next sequence example, also skips over
......@@ -189,32 +190,32 @@ class Ava(object):
Yields:
Each prepared tf.SequenceExample of metadata also containing video frames
"""
fieldnames = ["id", "timestamp_seconds", "xmin", "ymin", "xmax", "ymax",
"action_label"]
fieldnames = ['id', 'timestamp_seconds', 'xmin', 'ymin', 'xmax', 'ymax',
'action_label']
frame_excluded = {}
# create a sparse, nested map of videos and frame indices.
with open(excluded_file, "r") as excluded:
with open(excluded_file, 'r') as excluded:
reader = csv.reader(excluded)
for row in reader:
frame_excluded[(row[0], int(float(row[1])))] = True
with open(annotation_file, "r") as annotations:
with open(annotation_file, 'r') as annotations:
reader = csv.DictReader(annotations, fieldnames)
frame_annotations = collections.defaultdict(list)
ids = set()
# aggregate by video and timestamp:
for row in reader:
ids.add(row["id"])
key = (row["id"], int(float(row["timestamp_seconds"])))
ids.add(row['id'])
key = (row['id'], int(float(row['timestamp_seconds'])))
frame_annotations[key].append(row)
# for each video, find aggregates near each sampled frame:
logging.info("Generating metadata...")
logging.info('Generating metadata...')
media_num = 1
for media_id in ids:
logging.info("%d/%d, ignore warnings.\n" % (media_num, len(ids)))
logging.info('%d/%d, ignore warnings.\n' % (media_num, len(ids)))
media_num += 1
filepath = glob.glob(
video_path_format_string.format(media_id) + "*")[0]
video_path_format_string.format(media_id) + '*')[0]
cur_vid = cv2.VideoCapture(filepath)
width = cur_vid.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cur_vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
......@@ -237,7 +238,7 @@ class Ava(object):
if (media_id, windowed_timestamp) in frame_excluded:
end_time += 1
windowed_timestamp += 1
logging.info("Ignoring and skipping excluded frame.")
logging.info('Ignoring and skipping excluded frame.')
continue
cur_vid.set(cv2.CAP_PROP_POS_MSEC,
......@@ -247,7 +248,7 @@ class Ava(object):
bufstring = buffer.tostring()
total_images.append(bufstring)
source_id = str(windowed_timestamp) + "_" + media_id
source_id = str(windowed_timestamp) + '_' + media_id
total_source_ids.append(source_id)
total_is_annotated.append(1)
......@@ -256,14 +257,14 @@ class Ava(object):
label_strings = []
confidences = []
for row in frame_annotations[(media_id, windowed_timestamp)]:
if len(row) > 2 and int(row["action_label"]) in label_map:
boxes.append([float(row["ymin"]), float(row["xmin"]),
float(row["ymax"]), float(row["xmax"])])
labels.append(int(row["action_label"]))
label_strings.append(label_map[int(row["action_label"])])
if len(row) > 2 and int(row['action_label']) in label_map:
boxes.append([float(row['ymin']), float(row['xmin']),
float(row['ymax']), float(row['xmax'])])
labels.append(int(row['action_label']))
label_strings.append(label_map[int(row['action_label'])])
confidences.append(1)
else:
logging.warning("Unknown label: %s", row["action_label"])
logging.warning('Unknown label: %s', row['action_label'])
total_boxes.append(boxes)
total_labels.append(labels)
......@@ -272,9 +273,10 @@ class Ava(object):
windowed_timestamp += 1
if len(total_boxes) > 0:
yield seq_example_util.make_sequence_example("AVA", media_id, total_images,
int(height), int(width), 'jpeg', total_source_ids, None, total_is_annotated,
total_boxes, total_label_strings, use_strs_for_source_id=True)
yield seq_example_util.make_sequence_example(
'AVA', media_id, total_images, int(height), int(width), 'jpeg',
total_source_ids, None, total_is_annotated, total_boxes,
total_label_strings, use_strs_for_source_id=True)
# Move middle_time_frame, skipping excluded frames
frames_mv = 0
......@@ -307,33 +309,33 @@ class Ava(object):
Yields:
Each prepared tf.Example of metadata also containing video frames
"""
fieldnames = ["id", "timestamp_seconds", "xmin", "ymin", "xmax", "ymax",
"action_label"]
fieldnames = ['id', 'timestamp_seconds', 'xmin', 'ymin', 'xmax', 'ymax',
'action_label']
frame_excluded = {}
# create a sparse, nested map of videos and frame indices.
with open(excluded_file, "r") as excluded:
with open(excluded_file, 'r') as excluded:
reader = csv.reader(excluded)
for row in reader:
frame_excluded[(row[0], int(float(row[1])))] = True
with open(annotation_file, "r") as annotations:
with open(annotation_file, 'r') as annotations:
reader = csv.DictReader(annotations, fieldnames)
frame_annotations = collections.defaultdict(list)
ids = set()
# aggregate by video and timestamp:
for row in reader:
ids.add(row["id"])
key = (row["id"], int(float(row["timestamp_seconds"])))
ids.add(row['id'])
key = (row['id'], int(float(row['timestamp_seconds'])))
frame_annotations[key].append(row)
# for each video, find aggregates near each sampled frame:
logging.info("Generating metadata...")
logging.info('Generating metadata...')
media_num = 1
for media_id in ids:
logging.info("%d/%d, ignore warnings.\n" % (media_num, len(ids)))
logging.info('%d/%d, ignore warnings.\n' % (media_num, len(ids)))
media_num += 1
filepath = glob.glob(
video_path_format_string.format(media_id) + "*")[0]
filename = filepath.split("/")[-1]
video_path_format_string.format(media_id) + '*')[0]
filename = filepath.split('/')[-1]
cur_vid = cv2.VideoCapture(filepath)
width = cur_vid.get(cv2.CAP_PROP_FRAME_WIDTH)
height = cur_vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
......@@ -348,7 +350,7 @@ class Ava(object):
cur_frame_num = 0
while middle_frame_time < POSSIBLE_TIMESTAMPS[-1]:
cur_vid.set(cv2.CAP_PROP_POS_MSEC,
(middle_frame_time) * SECONDS_TO_MILLI)
middle_frame_time * SECONDS_TO_MILLI)
success, image = cur_vid.read()
success, buffer = cv2.imencode('.jpg', image)
......@@ -356,11 +358,11 @@ class Ava(object):
if (media_id, middle_frame_time) in frame_excluded:
middle_frame_time += 1
logging.info("Ignoring and skipping excluded frame.")
logging.info('Ignoring and skipping excluded frame.')
continue
cur_frame_num += 1
source_id = str(middle_frame_time) + "_" + media_id
source_id = str(middle_frame_time) + '_' + media_id
xmins = []
xmaxs = []
......@@ -371,18 +373,18 @@ class Ava(object):
label_strings = []
confidences = []
for row in frame_annotations[(media_id, middle_frame_time)]:
if len(row) > 2 and int(row["action_label"]) in label_map:
xmins.append(float(row["xmin"]))
xmaxs.append(float(row["xmax"]))
ymins.append(float(row["ymin"]))
ymaxs.append(float(row["ymax"]))
if len(row) > 2 and int(row['action_label']) in label_map:
xmins.append(float(row['xmin']))
xmaxs.append(float(row['xmax']))
ymins.append(float(row['ymin']))
ymaxs.append(float(row['ymax']))
areas.append(float((xmaxs[-1] - xmins[-1]) *
(ymaxs[-1] - ymins[-1])) / 2)
labels.append(int(row["action_label"]))
label_strings.append(label_map[int(row["action_label"])])
labels.append(int(row['action_label']))
label_strings.append(label_map[int(row['action_label'])])
confidences.append(1)
else:
logging.warning("Unknown label: %s", row["action_label"])
logging.warning('Unknown label: %s', row['action_label'])
middle_frame_time += 1/3
if abs(middle_frame_time - round(middle_frame_time) < 0.0001):
......@@ -390,7 +392,7 @@ class Ava(object):
key = hashlib.sha256(bufstring).hexdigest()
date_captured_feature = (
"2020-06-17 00:%02d:%02d" % ((middle_frame_time - 900)*3 // 60,
'2020-06-17 00:%02d:%02d' % ((middle_frame_time - 900)*3 // 60,
(middle_frame_time - 900)*3 % 60))
context_feature_dict = {
'image/height':
......@@ -400,9 +402,9 @@ class Ava(object):
'image/format':
dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/source_id':
dataset_util.bytes_feature(source_id.encode("utf8")),
dataset_util.bytes_feature(source_id.encode('utf8')),
'image/filename':
dataset_util.bytes_feature(source_id.encode("utf8")),
dataset_util.bytes_feature(source_id.encode('utf8')),
'image/encoded':
dataset_util.bytes_feature(bufstring),
'image/key/sha256':
......@@ -424,7 +426,8 @@ class Ava(object):
'image/location':
dataset_util.bytes_feature(media_id.encode('utf8')),
'image/date_captured':
dataset_util.bytes_feature(date_captured_feature.encode('utf8')),
dataset_util.bytes_feature(
date_captured_feature.encode('utf8')),
'image/seq_num_frames':
dataset_util.int64_feature(total_non_excluded),
'image/seq_frame_num':
......@@ -444,28 +447,28 @@ class Ava(object):
urlretrieve = urllib.request.urlretrieve
else:
urlretrieve = urllib.request.urlretrieve
logging.info("Creating data directory.")
logging.info('Creating data directory.')
tf.io.gfile.makedirs(self.path_to_data_download)
logging.info("Downloading annotations.")
logging.info('Downloading annotations.')
paths = {}
zip_path = os.path.join(self.path_to_data_download,
ANNOTATION_URL.split("/")[-1])
ANNOTATION_URL.split('/')[-1])
urlretrieve(ANNOTATION_URL, zip_path)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(self.path_to_data_download)
for split in ["train", "test", "val"]:
for split in ['train', 'test', 'val']:
csv_path = os.path.join(self.path_to_data_download,
"ava_%s_v2.2.csv" % split)
excl_name = "ava_%s_excluded_timestamps_v2.2.csv" % split
'ava_%s_v2.2.csv' % split)
excl_name = 'ava_%s_excluded_timestamps_v2.2.csv' % split
excluded_csv_path = os.path.join(self.path_to_data_download, excl_name)
SPLITS[split]["csv"] = csv_path
SPLITS[split]["excluded-csv"] = excluded_csv_path
SPLITS[split]['csv'] = csv_path
SPLITS[split]['excluded-csv'] = excluded_csv_path
paths[split] = (csv_path, excluded_csv_path)
label_map = self.get_label_map(os.path.join(
self.path_to_data_download,
"ava_action_list_v2.2_for_activitynet_2019.pbtxt"))
'ava_action_list_v2.2_for_activitynet_2019.pbtxt'))
return paths, label_map
def get_label_map(self, path):
......@@ -487,7 +490,7 @@ def _close_on_exit(writers):
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
raise app.UsageError('Too many command-line arguments.')
Ava(flags.FLAGS.path_to_output_dir,
flags.FLAGS.path_to_download_data).generate_and_write_records(
flags.FLAGS.splits_to_process,
......@@ -496,34 +499,34 @@ def main(argv):
flags.FLAGS.hop_between_sequences,
flags.FLAGS.examples_for_context)
if __name__ == "__main__":
flags.DEFINE_string("path_to_download_data",
"",
"Path to directory to download data to.")
flags.DEFINE_string("path_to_output_dir",
"",
"Path to directory to write data to.")
flags.DEFINE_string("splits_to_process",
"train,val",
"Process these splits. Useful for custom data splits.")
flags.DEFINE_string("video_path_format_string",
if __name__ == '__main__':
flags.DEFINE_string('path_to_download_data',
'',
'Path to directory to download data to.')
flags.DEFINE_string('path_to_output_dir',
'',
'Path to directory to write data to.')
flags.DEFINE_string('splits_to_process',
'train,val',
'Process these splits. Useful for custom data splits.')
flags.DEFINE_string('video_path_format_string',
None,
"The format string for the path to local video files. "
"Uses the Python string.format() syntax with possible "
"arguments of {video}, {start}, {end}, {label_name}, and "
"{split}, corresponding to columns of the data csvs.")
flags.DEFINE_integer("seconds_per_sequence",
'The format string for the path to local video files. '
'Uses the Python string.format() syntax with possible '
'arguments of {video}, {start}, {end}, {label_name}, and '
'{split}, corresponding to columns of the data csvs.')
flags.DEFINE_integer('seconds_per_sequence',
10,
"The number of seconds per example in each example. Always"
"1 when examples_for_context is True.")
flags.DEFINE_integer("hop_between_sequences",
'The number of seconds per example in each example.'
'Always 1 when examples_for_context is True.')
flags.DEFINE_integer('hop_between_sequences',
10,
"The hop between sequences. If less than "
"seconds_per_sequence, will overlap. Always 1 when "
"examples_for_context is True.")
flags.DEFINE_boolean("examples_for_context",
'The hop between sequences. If less than '
'seconds_per_sequence, will overlap. Always 1 when '
'examples_for_context is True.')
flags.DEFINE_boolean('examples_for_context',
False,
"Whether to generate examples instead of sequence examples. "
"If true, will generate tf.Example objects "
"for use in Context R-CNN.")
'Whether to generate examples instead of sequence '
'examples. If true, will generate tf.Example objects '
'for use in Context R-CNN.')
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment