Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6faf56a6
Commit
6faf56a6
authored
May 03, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
May 03, 2021
Browse files
Make apply_nms configurable.
PiperOrigin-RevId: 371735494
parent
a90e36a4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
47 additions
and
17 deletions
+47
-17
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+1
-0
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+1
-0
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+2
-2
official/vision/beta/modeling/maskrcnn_model.py
official/vision/beta/modeling/maskrcnn_model.py
+14
-4
official/vision/beta/modeling/retinanet_model.py
official/vision/beta/modeling/retinanet_model.py
+13
-4
official/vision/beta/serving/detection.py
official/vision/beta/serving/detection.py
+16
-7
No files found.
official/vision/beta/configs/maskrcnn.py
View file @
6faf56a6
...
@@ -144,6 +144,7 @@ class ROIAligner(hyperparams.Config):
...
@@ -144,6 +144,7 @@ class ROIAligner(hyperparams.Config):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
DetectionGenerator
(
hyperparams
.
Config
):
class
DetectionGenerator
(
hyperparams
.
Config
):
apply_nms
:
bool
=
True
pre_nms_top_k
:
int
=
5000
pre_nms_top_k
:
int
=
5000
pre_nms_score_threshold
:
float
=
0.05
pre_nms_score_threshold
:
float
=
0.05
nms_iou_threshold
:
float
=
0.5
nms_iou_threshold
:
float
=
0.5
...
...
official/vision/beta/configs/retinanet.py
View file @
6faf56a6
...
@@ -106,6 +106,7 @@ class RetinaNetHead(hyperparams.Config):
...
@@ -106,6 +106,7 @@ class RetinaNetHead(hyperparams.Config):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
DetectionGenerator
(
hyperparams
.
Config
):
class
DetectionGenerator
(
hyperparams
.
Config
):
apply_nms
:
bool
=
True
pre_nms_top_k
:
int
=
5000
pre_nms_top_k
:
int
=
5000
pre_nms_score_threshold
:
float
=
0.05
pre_nms_score_threshold
:
float
=
0.05
nms_iou_threshold
:
float
=
0.5
nms_iou_threshold
:
float
=
0.5
...
...
official/vision/beta/modeling/factory.py
View file @
6faf56a6
...
@@ -160,7 +160,7 @@ def build_maskrcnn(
...
@@ -160,7 +160,7 @@ def build_maskrcnn(
sample_offset
=
roi_aligner_config
.
sample_offset
)
sample_offset
=
roi_aligner_config
.
sample_offset
)
detection_generator_obj
=
detection_generator
.
DetectionGenerator
(
detection_generator_obj
=
detection_generator
.
DetectionGenerator
(
apply_nms
=
True
,
apply_nms
=
generator_config
.
apply_nms
,
pre_nms_top_k
=
generator_config
.
pre_nms_top_k
,
pre_nms_top_k
=
generator_config
.
pre_nms_top_k
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
...
@@ -255,7 +255,7 @@ def build_retinanet(
...
@@ -255,7 +255,7 @@ def build_retinanet(
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
)
detection_generator_obj
=
detection_generator
.
MultilevelDetectionGenerator
(
detection_generator_obj
=
detection_generator
.
MultilevelDetectionGenerator
(
apply_nms
=
True
,
apply_nms
=
generator_config
.
apply_nms
,
pre_nms_top_k
=
generator_config
.
pre_nms_top_k
,
pre_nms_top_k
=
generator_config
.
pre_nms_top_k
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
pre_nms_score_threshold
=
generator_config
.
pre_nms_score_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
nms_iou_threshold
=
generator_config
.
nms_iou_threshold
,
...
...
official/vision/beta/modeling/maskrcnn_model.py
View file @
6faf56a6
...
@@ -216,11 +216,21 @@ class MaskRCNNModel(tf.keras.Model):
...
@@ -216,11 +216,21 @@ class MaskRCNNModel(tf.keras.Model):
regression_weights
,
regression_weights
,
bbox_per_class
=
(
not
self
.
_config_dict
[
'class_agnostic_bbox_pred'
]))
bbox_per_class
=
(
not
self
.
_config_dict
[
'class_agnostic_bbox_pred'
]))
model_outputs
.
update
({
model_outputs
.
update
({
'detection_boxes'
:
detections
[
'detection_boxes'
],
'cls_outputs'
:
class_outputs
,
'detection_scores'
:
detections
[
'detection_scores'
],
'box_outputs'
:
box_outputs
,
'detection_classes'
:
detections
[
'detection_classes'
],
'num_detections'
:
detections
[
'num_detections'
],
})
})
if
self
.
detection_generator
.
get_config
()[
'apply_nms'
]:
model_outputs
.
update
({
'detection_boxes'
:
detections
[
'detection_boxes'
],
'detection_scores'
:
detections
[
'detection_scores'
],
'detection_classes'
:
detections
[
'detection_classes'
],
'num_detections'
:
detections
[
'num_detections'
]
})
else
:
model_outputs
.
update
({
'decoded_boxes'
:
detections
[
'decoded_boxes'
],
'decoded_box_scores'
:
detections
[
'decoded_box_scores'
]
})
if
not
self
.
_include_mask
:
if
not
self
.
_include_mask
:
return
model_outputs
return
model_outputs
...
...
official/vision/beta/modeling/retinanet_model.py
View file @
6faf56a6
...
@@ -148,13 +148,22 @@ class RetinaNetModel(tf.keras.Model):
...
@@ -148,13 +148,22 @@ class RetinaNetModel(tf.keras.Model):
final_results
=
self
.
detection_generator
(
final_results
=
self
.
detection_generator
(
raw_boxes
,
raw_scores
,
anchor_boxes
,
image_shape
,
raw_attributes
)
raw_boxes
,
raw_scores
,
anchor_boxes
,
image_shape
,
raw_attributes
)
outputs
=
{
outputs
=
{
'detection_boxes'
:
final_results
[
'detection_boxes'
],
'detection_scores'
:
final_results
[
'detection_scores'
],
'detection_classes'
:
final_results
[
'detection_classes'
],
'num_detections'
:
final_results
[
'num_detections'
],
'cls_outputs'
:
raw_scores
,
'cls_outputs'
:
raw_scores
,
'box_outputs'
:
raw_boxes
,
'box_outputs'
:
raw_boxes
,
}
}
if
self
.
detection_generator
.
get_config
()[
'apply_nms'
]:
outputs
.
update
({
'detection_boxes'
:
final_results
[
'detection_boxes'
],
'detection_scores'
:
final_results
[
'detection_scores'
],
'detection_classes'
:
final_results
[
'detection_classes'
],
'num_detections'
:
final_results
[
'num_detections'
]
})
else
:
outputs
.
update
({
'decoded_boxes'
:
final_results
[
'decoded_boxes'
],
'decoded_box_scores'
:
final_results
[
'decoded_box_scores'
]
})
if
raw_attributes
:
if
raw_attributes
:
outputs
.
update
({
outputs
.
update
({
'att_outputs'
:
raw_attributes
,
'att_outputs'
:
raw_attributes
,
...
...
official/vision/beta/serving/detection.py
View file @
6faf56a6
...
@@ -126,14 +126,23 @@ class DetectionModule(export_base.ExportModule):
...
@@ -126,14 +126,23 @@ class DetectionModule(export_base.ExportModule):
anchor_boxes
=
anchor_boxes
,
anchor_boxes
=
anchor_boxes
,
training
=
False
)
training
=
False
)
final_outputs
=
{
if
self
.
params
.
task
.
model
.
detection_generator
.
apply_nms
:
'detection_boxes'
:
detections
[
'detection_boxes'
],
final_outputs
=
{
'detection_scores'
:
detections
[
'detection_scores'
],
'detection_boxes'
:
detections
[
'detection_boxes'
],
'detection_classes'
:
detections
[
'detection_classes'
],
'detection_scores'
:
detections
[
'detection_scores'
],
'num_detections'
:
detections
[
'num_detections'
],
'detection_classes'
:
detections
[
'detection_classes'
],
'image_info'
:
image_info
'num_detections'
:
detections
[
'num_detections'
]
}
}
else
:
final_outputs
=
{
'decoded_boxes'
:
detections
[
'decoded_boxes'
],
'decoded_box_scores'
:
detections
[
'decoded_box_scores'
],
'cls_outputs'
:
detections
[
'cls_outputs'
],
'box_outputs'
:
detections
[
'box_outputs'
]
}
if
'detection_masks'
in
detections
.
keys
():
if
'detection_masks'
in
detections
.
keys
():
final_outputs
[
'detection_masks'
]
=
detections
[
'detection_masks'
]
final_outputs
[
'detection_masks'
]
=
detections
[
'detection_masks'
]
final_outputs
.
update
({
'image_info'
:
image_info
})
return
final_outputs
return
final_outputs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment