Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
603d702d
Unverified
Commit
603d702d
authored
Sep 12, 2021
by
srihari-humbarwadi
Browse files
replaced for loop with `tf.map_fn`
parent
cfc9f1f7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
27 deletions
+18
-27
official/vision/beta/projects/panoptic_maskrcnn/modeling/layers/panoptic_segmentation_generator.py
...skrcnn/modeling/layers/panoptic_segmentation_generator.py
+18
-27
No files found.
official/vision/beta/projects/panoptic_maskrcnn/modeling/layers/panoptic_segmentation_generator.py
View file @
603d702d
...
...
@@ -240,33 +240,24 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
tf
.
argmax
(
batched_segmentation_masks
,
axis
=-
1
),
dtype
=
tf
.
float32
),
axis
=-
1
)
batch_size
,
_
,
_
=
batched_boxes
.
get_shape
().
as_list
()
if
batch_size
is
None
:
batch_size
=
tf
.
shape
(
batched_boxes
)[
0
]
category_mask
=
[]
instance_mask
=
[]
for
idx
in
range
(
batch_size
):
results
=
self
.
_generate_panoptic_masks
(
boxes
=
batched_boxes
[
idx
],
scores
=
batched_scores
[
idx
],
classes
=
batched_classes
[
idx
],
detections_masks
=
batched_detections_masks
[
idx
],
segmentation_mask
=
batched_segmentation_masks
[
idx
])
category_mask
.
append
(
results
[
'category_mask'
])
instance_mask
.
append
(
results
[
'instance_mask'
])
category_mask
=
tf
.
stack
(
category_mask
,
axis
=
0
)
instance_mask
=
tf
.
stack
(
instance_mask
,
axis
=
0
)
outputs
=
{
'category_mask'
:
tf
.
cast
(
category_mask
,
dtype
=
tf
.
int32
),
'instance_mask'
:
tf
.
cast
(
instance_mask
,
dtype
=
tf
.
int32
)
}
return
outputs
panoptic_masks
=
tf
.
map_fn
(
fn
=
lambda
x
:
self
.
_generate_panoptic_masks
(
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
],
x
[
4
]),
elems
=
(
batched_boxes
,
batched_scores
,
batched_classes
,
batched_detections_masks
,
batched_segmentation_masks
),
fn_output_signature
=
{
'category_mask'
:
tf
.
float32
,
'instance_mask'
:
tf
.
float32
})
for
k
,
v
in
panoptic_masks
.
items
():
panoptic_masks
[
k
]
=
tf
.
cast
(
v
,
dtype
=
tf
.
int32
)
return
panoptic_masks
def
get_config
(
self
):
return
self
.
_config_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment