Unverified Commit 603d702d authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

replaced for loop with `tf.map_fn`

parent cfc9f1f7
......@@ -240,33 +240,24 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1)
batch_size, _, _ = batched_boxes.get_shape().as_list()
if batch_size is None:
batch_size = tf.shape(batched_boxes)[0]
category_mask = []
instance_mask = []
for idx in range(batch_size):
results = self._generate_panoptic_masks(
boxes=batched_boxes[idx],
scores=batched_scores[idx],
classes=batched_classes[idx],
detections_masks=batched_detections_masks[idx],
segmentation_mask=batched_segmentation_masks[idx])
category_mask.append(results['category_mask'])
instance_mask.append(results['instance_mask'])
category_mask = tf.stack(category_mask, axis=0)
instance_mask = tf.stack(instance_mask, axis=0)
outputs = {
'category_mask': tf.cast(category_mask, dtype=tf.int32),
'instance_mask': tf.cast(instance_mask, dtype=tf.int32)
}
return outputs
panoptic_masks = tf.map_fn(
fn=lambda x: self._generate_panoptic_masks(
x[0], x[1], x[2], x[3], x[4]),
elems=(
batched_boxes,
batched_scores,
batched_classes,
batched_detections_masks,
batched_segmentation_masks),
fn_output_signature={
'category_mask': tf.float32,
'instance_mask': tf.float32
})
for k, v in panoptic_masks.items():
panoptic_masks[k] = tf.cast(v, dtype=tf.int32)
return panoptic_masks
def get_config(self):
return self._config_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment