Commit 2db4cc8f authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Make timestamp optional in `generate_embedding_data.py`

PiperOrigin-RevId: 324826632
parent a811a3b7
...@@ -72,6 +72,67 @@ def drop_keys(key_value_tuple): ...@@ -72,6 +72,67 @@ def drop_keys(key_value_tuple):
return key_value_tuple[1] return key_value_tuple[1]
def get_date_captured(example):
date_captured = datetime.datetime.strptime(
six.ensure_str(
example.features.feature['image/date_captured'].bytes_list.value[0]),
'%Y-%m-%d %H:%M:%S')
return date_captured
def embed_date_captured(date_captured):
"""Encodes the datetime of the image."""
embedded_date_captured = []
month_max = 12.0
day_max = 31.0
hour_max = 24.0
minute_max = 60.0
min_year = 1990.0
max_year = 2030.0
year = (date_captured.year - min_year) / float(max_year - min_year)
embedded_date_captured.append(year)
month = (date_captured.month - 1) / month_max
embedded_date_captured.append(month)
day = (date_captured.day - 1) / day_max
embedded_date_captured.append(day)
hour = date_captured.hour / hour_max
embedded_date_captured.append(hour)
minute = date_captured.minute / minute_max
embedded_date_captured.append(minute)
return np.asarray(embedded_date_captured)
def embed_position_and_size(box):
"""Encodes the bounding box of the object of interest."""
ymin = box[0]
xmin = box[1]
ymax = box[2]
xmax = box[3]
w = xmax - xmin
h = ymax - ymin
x = xmin + w / 2.0
y = ymin + h / 2.0
return np.asarray([x, y, w, h])
def get_bb_embedding(detection_features, detection_boxes, detection_scores,
index):
embedding = detection_features[0][index]
pooled_embedding = np.mean(np.mean(embedding, axis=1), axis=0)
box = detection_boxes[0][index]
position_embedding = embed_position_and_size(box)
score = detection_scores[0][index]
return np.concatenate((pooled_embedding, position_embedding)), score
class GenerateEmbeddingDataFn(beam.DoFn): class GenerateEmbeddingDataFn(beam.DoFn):
"""Generates embedding data for camera trap images. """Generates embedding data for camera trap images.
...@@ -112,68 +173,18 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -112,68 +173,18 @@ class GenerateEmbeddingDataFn(beam.DoFn):
def _run_inference_and_generate_embedding(self, tfexample_key_value): def _run_inference_and_generate_embedding(self, tfexample_key_value):
key, tfexample = tfexample_key_value key, tfexample = tfexample_key_value
input_example = tf.train.Example.FromString(tfexample) input_example = tf.train.Example.FromString(tfexample)
# Convert date_captured datetime string to unix time integer and store example = tf.train.Example()
example.CopyFrom(input_example)
def get_date_captured(example):
date_captured = datetime.datetime.strptime(
six.ensure_str(
example.features.feature[
'image/date_captured'].bytes_list.value[0]),
'%Y-%m-%d %H:%M:%S')
return date_captured
try: try:
date_captured = get_date_captured(input_example) date_captured = get_date_captured(input_example)
unix_time = ((date_captured -
datetime.datetime.fromtimestamp(0)).total_seconds())
example.features.feature['image/unix_time'].float_list.value.extend(
[unix_time])
temporal_embedding = embed_date_captured(date_captured)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
# we require date_captured to be available for all images pass
return []
def embed_date_captured(date_captured):
"""Encodes the datetime of the image."""
embedded_date_captured = []
month_max = 12.0
day_max = 31.0
hour_max = 24.0
minute_max = 60.0
min_year = 1990.0
max_year = 2030.0
year = (date_captured.year-min_year)/float(max_year-min_year)
embedded_date_captured.append(year)
month = (date_captured.month-1)/month_max
embedded_date_captured.append(month)
day = (date_captured.day-1)/day_max
embedded_date_captured.append(day)
hour = date_captured.hour/hour_max
embedded_date_captured.append(hour)
minute = date_captured.minute/minute_max
embedded_date_captured.append(minute)
return np.asarray(embedded_date_captured)
def embed_position_and_size(box):
"""Encodes the bounding box of the object of interest."""
ymin = box[0]
xmin = box[1]
ymax = box[2]
xmax = box[3]
w = xmax - xmin
h = ymax - ymin
x = xmin + w / 2.0
y = ymin + h / 2.0
return np.asarray([x, y, w, h])
unix_time = (
(date_captured - datetime.datetime.fromtimestamp(0)).total_seconds())
example = tf.train.Example()
example.CopyFrom(input_example)
example.features.feature['image/unix_time'].float_list.value.extend(
[unix_time])
detections = self._detect_fn.signatures['serving_default']( detections = self._detect_fn.signatures['serving_default'](
(tf.expand_dims(tf.convert_to_tensor(tfexample), 0))) (tf.expand_dims(tf.convert_to_tensor(tfexample), 0)))
...@@ -188,25 +199,12 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -188,25 +199,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
detection_features = np.asarray(detection_features) detection_features = np.asarray(detection_features)
def get_bb_embedding(detection_features, detection_boxes, detection_scores,
index):
embedding = detection_features[0][index]
pooled_embedding = np.mean(np.mean(embedding, axis=1), axis=0)
box = detection_boxes[0][index]
position_embedding = embed_position_and_size(box)
score = detection_scores[0][index]
return np.concatenate((pooled_embedding, position_embedding)), score
temporal_embedding = embed_date_captured(date_captured)
embedding_count = 0 embedding_count = 0
for index in range(min(num_detections, self._top_k_embedding_count)): for index in range(min(num_detections, self._top_k_embedding_count)):
bb_embedding, score = get_bb_embedding( bb_embedding, score = get_bb_embedding(
detection_features, detection_boxes, detection_scores, index) detection_features, detection_boxes, detection_scores, index)
embed_all.extend(bb_embedding) embed_all.extend(bb_embedding)
embed_all.extend(temporal_embedding) if temporal_embedding is not None: embed_all.extend(temporal_embedding)
score_all.append(score) score_all.append(score)
embedding_count += 1 embedding_count += 1
...@@ -216,7 +214,7 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -216,7 +214,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):
bb_embedding, score = get_bb_embedding( bb_embedding, score = get_bb_embedding(
detection_features, detection_boxes, detection_scores, index) detection_features, detection_boxes, detection_scores, index)
embed_all.extend(bb_embedding) embed_all.extend(bb_embedding)
embed_all.extend(temporal_embedding) if temporal_embedding is not None: embed_all.extend(temporal_embedding)
score_all.append(score) score_all.append(score)
embedding_count += 1 embedding_count += 1
...@@ -224,7 +222,7 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -224,7 +222,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):
bb_embedding, score = get_bb_embedding( bb_embedding, score = get_bb_embedding(
detection_features, detection_boxes, detection_scores, 0) detection_features, detection_boxes, detection_scores, 0)
embed_all.extend(bb_embedding) embed_all.extend(bb_embedding)
embed_all.extend(temporal_embedding) if temporal_embedding is not None: embed_all.extend(temporal_embedding)
score_all.append(score) score_all.append(score)
# Takes max in case embedding_count is 0. # Takes max in case embedding_count is 0.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment