Unverified Commit 603d702d authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

replaced for loop with `tf.map_fn`

parent cfc9f1f7
...@@ -240,33 +240,24 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer): ...@@ -240,33 +240,24 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
tf.argmax(batched_segmentation_masks, axis=-1), tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1) dtype=tf.float32), axis=-1)
batch_size, _, _ = batched_boxes.get_shape().as_list() panoptic_masks = tf.map_fn(
fn=lambda x: self._generate_panoptic_masks(
if batch_size is None: x[0], x[1], x[2], x[3], x[4]),
batch_size = tf.shape(batched_boxes)[0] elems=(
batched_boxes,
category_mask = [] batched_scores,
instance_mask = [] batched_classes,
batched_detections_masks,
for idx in range(batch_size): batched_segmentation_masks),
results = self._generate_panoptic_masks( fn_output_signature={
boxes=batched_boxes[idx], 'category_mask': tf.float32,
scores=batched_scores[idx], 'instance_mask': tf.float32
classes=batched_classes[idx], })
detections_masks=batched_detections_masks[idx],
segmentation_mask=batched_segmentation_masks[idx]) for k, v in panoptic_masks.items():
panoptic_masks[k] = tf.cast(v, dtype=tf.int32)
category_mask.append(results['category_mask'])
instance_mask.append(results['instance_mask']) return panoptic_masks
category_mask = tf.stack(category_mask, axis=0)
instance_mask = tf.stack(instance_mask, axis=0)
outputs = {
'category_mask': tf.cast(category_mask, dtype=tf.int32),
'instance_mask': tf.cast(instance_mask, dtype=tf.int32)
}
return outputs
def get_config(self): def get_config(self):
return self._config_dict return self._config_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment