Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a4d9c3a0
Commit
a4d9c3a0
authored
Apr 09, 2018
by
Zhichao Lu
Committed by
pkulzc
Apr 13, 2018
Browse files
Class agnostic masks for mask_rcnn
PiperOrigin-RevId: 192132440
parent
bfd15ec1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
34 deletions
+73
-34
research/object_detection/builders/box_predictor_builder.py
research/object_detection/builders/box_predictor_builder.py
+2
-0
research/object_detection/core/box_predictor.py
research/object_detection/core/box_predictor.py
+6
-1
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py
...ect_detection/meta_architectures/faster_rcnn_meta_arch.py
+34
-21
research/object_detection/meta_architectures/faster_rcnn_meta_arch_test.py
...etection/meta_architectures/faster_rcnn_meta_arch_test.py
+19
-7
research/object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py
...tion/meta_architectures/faster_rcnn_meta_arch_test_lib.py
+11
-5
research/object_detection/protos/box_predictor.proto
research/object_detection/protos/box_predictor.proto
+1
-0
No files found.
research/object_detection/builders/box_predictor_builder.py
View file @
a4d9c3a0
...
@@ -111,6 +111,8 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
...
@@ -111,6 +111,8 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
mask_rcnn_box_predictor
.
mask_prediction_num_conv_layers
),
mask_rcnn_box_predictor
.
mask_prediction_num_conv_layers
),
mask_prediction_conv_depth
=
(
mask_prediction_conv_depth
=
(
mask_rcnn_box_predictor
.
mask_prediction_conv_depth
),
mask_rcnn_box_predictor
.
mask_prediction_conv_depth
),
masks_are_class_agnostic
=
(
mask_rcnn_box_predictor
.
masks_are_class_agnostic
),
predict_keypoints
=
mask_rcnn_box_predictor
.
predict_keypoints
)
predict_keypoints
=
mask_rcnn_box_predictor
.
predict_keypoints
)
return
box_predictor_object
return
box_predictor_object
...
...
research/object_detection/core/box_predictor.py
View file @
a4d9c3a0
...
@@ -307,6 +307,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
...
@@ -307,6 +307,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
mask_width
=
14
,
mask_width
=
14
,
mask_prediction_num_conv_layers
=
2
,
mask_prediction_num_conv_layers
=
2
,
mask_prediction_conv_depth
=
256
,
mask_prediction_conv_depth
=
256
,
masks_are_class_agnostic
=
False
,
predict_keypoints
=
False
):
predict_keypoints
=
False
):
"""Constructor.
"""Constructor.
...
@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
...
@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
to 0, the depth of the convolution layers will be automatically chosen
to 0, the depth of the convolution layers will be automatically chosen
based on the number of object classes and the number of channels in the
based on the number of object classes and the number of channels in the
image features.
image features.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
predict_keypoints: Whether to predict keypoints insde detection boxes.
predict_keypoints: Whether to predict keypoints insde detection boxes.
...
@@ -357,6 +360,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
...
@@ -357,6 +360,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
self
.
_mask_width
=
mask_width
self
.
_mask_width
=
mask_width
self
.
_mask_prediction_num_conv_layers
=
mask_prediction_num_conv_layers
self
.
_mask_prediction_num_conv_layers
=
mask_prediction_num_conv_layers
self
.
_mask_prediction_conv_depth
=
mask_prediction_conv_depth
self
.
_mask_prediction_conv_depth
=
mask_prediction_conv_depth
self
.
_masks_are_class_agnostic
=
masks_are_class_agnostic
self
.
_predict_keypoints
=
predict_keypoints
self
.
_predict_keypoints
=
predict_keypoints
if
self
.
_predict_keypoints
:
if
self
.
_predict_keypoints
:
raise
ValueError
(
'Keypoint prediction is unimplemented.'
)
raise
ValueError
(
'Keypoint prediction is unimplemented.'
)
...
@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor):
...
@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor):
upsampled_features
,
upsampled_features
,
num_outputs
=
num_conv_channels
,
num_outputs
=
num_conv_channels
,
kernel_size
=
[
3
,
3
])
kernel_size
=
[
3
,
3
])
num_masks
=
1
if
self
.
_masks_are_class_agnostic
else
self
.
num_classes
mask_predictions
=
slim
.
conv2d
(
upsampled_features
,
mask_predictions
=
slim
.
conv2d
(
upsampled_features
,
num_outputs
=
self
.
num_classe
s
,
num_outputs
=
num_mask
s
,
activation_fn
=
None
,
activation_fn
=
None
,
kernel_size
=
[
3
,
3
])
kernel_size
=
[
3
,
3
])
return
tf
.
expand_dims
(
return
tf
.
expand_dims
(
...
...
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py
View file @
a4d9c3a0
...
@@ -768,9 +768,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -768,9 +768,11 @@ class FasterRCNNMetaArch(model.DetectionModel):
predict_auxiliary_outputs
=
predict_auxiliary_outputs
)
predict_auxiliary_outputs
=
predict_auxiliary_outputs
)
refined_box_encodings
=
tf
.
squeeze
(
refined_box_encodings
=
tf
.
squeeze
(
box_predictions
[
box_predictor
.
BOX_ENCODINGS
],
axis
=
1
)
box_predictions
[
box_predictor
.
BOX_ENCODINGS
],
class_predictions_with_background
=
tf
.
squeeze
(
box_predictions
[
axis
=
1
,
name
=
'all_refined_box_encodings'
)
box_predictor
.
CLASS_PREDICTIONS_WITH_BACKGROUND
],
axis
=
1
)
class_predictions_with_background
=
tf
.
squeeze
(
box_predictions
[
box_predictor
.
CLASS_PREDICTIONS_WITH_BACKGROUND
],
axis
=
1
,
name
=
'all_class_predictions_with_background'
)
absolute_proposal_boxes
=
ops
.
normalized_to_image_coordinates
(
absolute_proposal_boxes
=
ops
.
normalized_to_image_coordinates
(
proposal_boxes_normalized
,
image_shape
,
self
.
_parallel_iterations
)
proposal_boxes_normalized
,
image_shape
,
self
.
_parallel_iterations
)
...
@@ -794,6 +796,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -794,6 +796,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
def
_predict_third_stage
(
self
,
prediction_dict
,
image_shapes
):
def
_predict_third_stage
(
self
,
prediction_dict
,
image_shapes
):
"""Predicts non-box, non-class outputs using refined detections.
"""Predicts non-box, non-class outputs using refined detections.
This happens after calling the post-processing stage, such that masks
are only calculated for the top scored boxes.
Args:
Args:
prediction_dict: a dictionary holding "raw" prediction tensors:
prediction_dict: a dictionary holding "raw" prediction tensors:
1) refined_box_encodings: a 3-D tensor with shape
1) refined_box_encodings: a 3-D tensor with shape
...
@@ -851,16 +856,21 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -851,16 +856,21 @@ class FasterRCNNMetaArch(model.DetectionModel):
scope
=
self
.
second_stage_box_predictor_scope
,
scope
=
self
.
second_stage_box_predictor_scope
,
predict_boxes_and_classes
=
False
,
predict_boxes_and_classes
=
False
,
predict_auxiliary_outputs
=
True
)
predict_auxiliary_outputs
=
True
)
if
box_predictor
.
MASK_PREDICTIONS
in
box_predictions
:
if
box_predictor
.
MASK_PREDICTIONS
in
box_predictions
:
detection_masks
=
tf
.
squeeze
(
box_predictions
[
detection_masks
=
tf
.
squeeze
(
box_predictions
[
box_predictor
.
MASK_PREDICTIONS
],
axis
=
1
)
box_predictor
.
MASK_PREDICTIONS
],
axis
=
1
)
detection_masks
=
self
.
_gather_instance_masks
(
detection_masks
,
_
,
num_classes
,
mask_height
,
mask_width
=
(
detection_classes
)
detection_masks
.
get_shape
().
as_list
())
mask_height
=
tf
.
shape
(
detection_masks
)[
1
]
_
,
max_detection
=
detection_classes
.
get_shape
().
as_list
()
mask_width
=
tf
.
shape
(
detection_masks
)[
2
]
if
num_classes
>
1
:
detection_masks
=
self
.
_gather_instance_masks
(
detection_masks
,
detection_classes
)
prediction_dict
[
fields
.
DetectionResultFields
.
detection_masks
]
=
(
prediction_dict
[
fields
.
DetectionResultFields
.
detection_masks
]
=
(
tf
.
reshape
(
detection_masks
,
tf
.
reshape
(
detection_masks
,
[
batch_size
,
max_detection
,
mask_height
,
mask_width
]))
[
batch_size
,
max_detection
,
mask_height
,
mask_width
]))
return
prediction_dict
return
prediction_dict
def
_gather_instance_masks
(
self
,
instance_masks
,
classes
):
def
_gather_instance_masks
(
self
,
instance_masks
,
classes
):
...
@@ -874,16 +884,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -874,16 +884,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
Returns:
Returns:
masks: a 3-D float32 tensor with shape [K, mask_height, mask_width].
masks: a 3-D float32 tensor with shape [K, mask_height, mask_width].
"""
"""
_
,
num_classes
,
height
,
width
=
instance_masks
.
get_shape
().
as_list
()
k
=
tf
.
shape
(
instance_masks
)[
0
]
k
=
tf
.
shape
(
instance_masks
)[
0
]
num_mask_classes
=
tf
.
shape
(
instance_masks
)[
1
]
instance_masks
=
tf
.
reshape
(
instance_masks
,
[
-
1
,
height
,
width
])
instance_mask_height
=
tf
.
shape
(
instance_masks
)[
2
]
classes
=
tf
.
to_int32
(
tf
.
reshape
(
classes
,
[
-
1
]))
instance_mask_width
=
tf
.
shape
(
instance_masks
)[
3
]
gather_idx
=
tf
.
range
(
k
)
*
num_classes
+
classes
classes
=
tf
.
reshape
(
classes
,
[
-
1
])
return
tf
.
gather
(
instance_masks
,
gather_idx
)
instance_masks
=
tf
.
reshape
(
instance_masks
,
[
-
1
,
instance_mask_height
,
instance_mask_width
])
return
tf
.
gather
(
instance_masks
,
tf
.
range
(
k
)
*
num_mask_classes
+
tf
.
to_int32
(
classes
))
def
_extract_rpn_feature_maps
(
self
,
preprocessed_inputs
):
def
_extract_rpn_feature_maps
(
self
,
preprocessed_inputs
):
"""Extracts RPN features.
"""Extracts RPN features.
...
@@ -1815,11 +1821,18 @@ class FasterRCNNMetaArch(model.DetectionModel):
...
@@ -1815,11 +1821,18 @@ class FasterRCNNMetaArch(model.DetectionModel):
# Pad the prediction_masks with to add zeros for background class to be
# Pad the prediction_masks with to add zeros for background class to be
# consistent with class predictions.
# consistent with class predictions.
if
prediction_masks
.
get_shape
().
as_list
()[
1
]
==
1
:
# Class agnostic masks or masks for one-class prediction. Logic for
# both cases is the same since background predictions are ignored
# through the batch_mask_target_weights.
prediction_masks_masked_by_class_targets
=
prediction_masks
else
:
prediction_masks_with_background
=
tf
.
pad
(
prediction_masks_with_background
=
tf
.
pad
(
prediction_masks
,
[[
0
,
0
],
[
1
,
0
],
[
0
,
0
],
[
0
,
0
]])
prediction_masks
,
[[
0
,
0
],
[
1
,
0
],
[
0
,
0
],
[
0
,
0
]])
prediction_masks_masked_by_class_targets
=
tf
.
boolean_mask
(
prediction_masks_masked_by_class_targets
=
tf
.
boolean_mask
(
prediction_masks_with_background
,
prediction_masks_with_background
,
tf
.
greater
(
one_hot_flat_cls_targets_with_background
,
0
))
tf
.
greater
(
one_hot_flat_cls_targets_with_background
,
0
))
mask_height
=
prediction_masks
.
shape
[
2
].
value
mask_height
=
prediction_masks
.
shape
[
2
].
value
mask_width
=
prediction_masks
.
shape
[
3
].
value
mask_width
=
prediction_masks
.
shape
[
3
].
value
reshaped_prediction_masks
=
tf
.
reshape
(
reshaped_prediction_masks
=
tf
.
reshape
(
...
...
research/object_detection/meta_architectures/faster_rcnn_meta_arch_test.py
View file @
a4d9c3a0
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
...
@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
class
FasterRCNNMetaArchTest
(
class
FasterRCNNMetaArchTest
(
faster_rcnn_meta_arch_test_lib
.
FasterRCNNMetaArchTestBase
):
faster_rcnn_meta_arch_test_lib
.
FasterRCNNMetaArchTestBase
,
parameterized
.
TestCase
):
def
test_postprocess_second_stage_only_inference_mode_with_masks
(
self
):
def
test_postprocess_second_stage_only_inference_mode_with_masks
(
self
):
model
=
self
.
_build_model
(
model
=
self
.
_build_model
(
...
@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest(
...
@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest(
self
.
assertTrue
(
np
.
amax
(
detections_out
[
'detection_masks'
]
<=
1.0
))
self
.
assertTrue
(
np
.
amax
(
detections_out
[
'detection_masks'
]
<=
1.0
))
self
.
assertTrue
(
np
.
amin
(
detections_out
[
'detection_masks'
]
>=
0.0
))
self
.
assertTrue
(
np
.
amin
(
detections_out
[
'detection_masks'
]
>=
0.0
))
@
parameterized
.
parameters
(
{
'masks_are_class_agnostic'
:
False
},
{
'masks_are_class_agnostic'
:
True
},
)
def
test_predict_correct_shapes_in_inference_mode_three_stages_with_masks
(
def
test_predict_correct_shapes_in_inference_mode_three_stages_with_masks
(
self
):
self
,
masks_are_class_agnostic
):
batch_size
=
2
batch_size
=
2
image_size
=
10
image_size
=
10
max_num_proposals
=
8
max_num_proposals
=
8
...
@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest(
...
@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest(
is_training
=
False
,
is_training
=
False
,
number_of_stages
=
3
,
number_of_stages
=
3
,
second_stage_batch_size
=
2
,
second_stage_batch_size
=
2
,
predict_masks
=
True
)
predict_masks
=
True
,
masks_are_class_agnostic
=
masks_are_class_agnostic
)
preprocessed_inputs
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
input_shape
)
preprocessed_inputs
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
input_shape
)
_
,
true_image_shapes
=
model
.
preprocess
(
preprocessed_inputs
)
_
,
true_image_shapes
=
model
.
preprocess
(
preprocessed_inputs
)
result_tensor_dict
=
model
.
predict
(
preprocessed_inputs
,
result_tensor_dict
=
model
.
predict
(
preprocessed_inputs
,
...
@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest(
...
@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest(
self
.
assertAllEqual
(
tensor_dict_out
[
'detection_scores'
].
shape
,
[
2
,
5
])
self
.
assertAllEqual
(
tensor_dict_out
[
'detection_scores'
].
shape
,
[
2
,
5
])
self
.
assertAllEqual
(
tensor_dict_out
[
'num_detections'
].
shape
,
[
2
])
self
.
assertAllEqual
(
tensor_dict_out
[
'num_detections'
].
shape
,
[
2
])
@
parameterized
.
parameters
(
{
'masks_are_class_agnostic'
:
False
},
{
'masks_are_class_agnostic'
:
True
},
)
def
test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks
(
def
test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks
(
self
):
self
,
masks_are_class_agnostic
):
test_graph
=
tf
.
Graph
()
test_graph
=
tf
.
Graph
()
with
test_graph
.
as_default
():
with
test_graph
.
as_default
():
model
=
self
.
_build_model
(
model
=
self
.
_build_model
(
is_training
=
True
,
is_training
=
True
,
number_of_stages
=
2
,
number_of_stages
=
2
,
second_stage_batch_size
=
7
,
second_stage_batch_size
=
7
,
predict_masks
=
True
)
predict_masks
=
True
,
masks_are_class_agnostic
=
masks_are_class_agnostic
)
batch_size
=
2
batch_size
=
2
image_size
=
10
image_size
=
10
max_num_proposals
=
7
max_num_proposals
=
7
...
@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest(
...
@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest(
groundtruth_classes_list
)
groundtruth_classes_list
)
result_tensor_dict
=
model
.
predict
(
preprocessed_inputs
,
true_image_shapes
)
result_tensor_dict
=
model
.
predict
(
preprocessed_inputs
,
true_image_shapes
)
mask_shape_1
=
1
if
masks_are_class_agnostic
else
model
.
_num_classes
expected_shapes
=
{
expected_shapes
=
{
'rpn_box_predictor_features'
:
(
2
,
image_size
,
image_size
,
512
),
'rpn_box_predictor_features'
:
(
2
,
image_size
,
image_size
,
512
),
'rpn_features_to_crop'
:
(
2
,
image_size
,
image_size
,
3
),
'rpn_features_to_crop'
:
(
2
,
image_size
,
image_size
,
3
),
...
@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest(
...
@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest(
self
.
_get_box_classifier_features_shape
(
self
.
_get_box_classifier_features_shape
(
image_size
,
batch_size
,
max_num_proposals
,
initial_crop_size
,
image_size
,
batch_size
,
max_num_proposals
,
initial_crop_size
,
maxpool_stride
,
3
),
maxpool_stride
,
3
),
'mask_predictions'
:
(
2
*
max_num_proposals
,
2
,
14
,
14
)
'mask_predictions'
:
(
2
*
max_num_proposals
,
mask_shape_1
,
14
,
14
)
}
}
init_op
=
tf
.
global_variables_initializer
()
init_op
=
tf
.
global_variables_initializer
()
...
...
research/object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py
View file @
a4d9c3a0
...
@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
"""
"""
return
box_predictor_text_proto
return
box_predictor_text_proto
def
_add_mask_to_second_stage_box_predictor_text_proto
(
self
):
def
_add_mask_to_second_stage_box_predictor_text_proto
(
self
,
masks_are_class_agnostic
=
False
):
agnostic
=
'true'
if
masks_are_class_agnostic
else
'false'
box_predictor_text_proto
=
"""
box_predictor_text_proto
=
"""
mask_rcnn_box_predictor {
mask_rcnn_box_predictor {
predict_instance_masks: true
predict_instance_masks: true
masks_are_class_agnostic: """
+
agnostic
+
"""
mask_height: 14
mask_height: 14
mask_width: 14
mask_width: 14
conv_hyperparams {
conv_hyperparams {
...
@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
return
box_predictor_text_proto
return
box_predictor_text_proto
def
_get_second_stage_box_predictor
(
self
,
num_classes
,
is_training
,
def
_get_second_stage_box_predictor
(
self
,
num_classes
,
is_training
,
predict_masks
):
predict_masks
,
masks_are_class_agnostic
):
box_predictor_proto
=
box_predictor_pb2
.
BoxPredictor
()
box_predictor_proto
=
box_predictor_pb2
.
BoxPredictor
()
text_format
.
Merge
(
self
.
_get_second_stage_box_predictor_text_proto
(),
text_format
.
Merge
(
self
.
_get_second_stage_box_predictor_text_proto
(),
box_predictor_proto
)
box_predictor_proto
)
if
predict_masks
:
if
predict_masks
:
text_format
.
Merge
(
text_format
.
Merge
(
self
.
_add_mask_to_second_stage_box_predictor_text_proto
(),
self
.
_add_mask_to_second_stage_box_predictor_text_proto
(
masks_are_class_agnostic
),
box_predictor_proto
)
box_predictor_proto
)
return
box_predictor_builder
.
build
(
return
box_predictor_builder
.
build
(
...
@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
hard_mining
=
False
,
hard_mining
=
False
,
softmax_second_stage_classification_loss
=
True
,
softmax_second_stage_classification_loss
=
True
,
predict_masks
=
False
,
predict_masks
=
False
,
pad_to_max_dimension
=
None
):
pad_to_max_dimension
=
None
,
masks_are_class_agnostic
=
False
):
def
image_resizer_fn
(
image
,
masks
=
None
):
def
image_resizer_fn
(
image
,
masks
=
None
):
"""Fake image resizer function."""
"""Fake image resizer function."""
...
@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
...
@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
self
.
_get_second_stage_box_predictor
(
self
.
_get_second_stage_box_predictor
(
num_classes
=
num_classes
,
num_classes
=
num_classes
,
is_training
=
is_training
,
is_training
=
is_training
,
predict_masks
=
predict_masks
),
**
common_kwargs
)
predict_masks
=
predict_masks
,
masks_are_class_agnostic
=
masks_are_class_agnostic
),
**
common_kwargs
)
def
test_predict_gives_correct_shapes_in_inference_mode_first_stage_only
(
def
test_predict_gives_correct_shapes_in_inference_mode_first_stage_only
(
self
):
self
):
...
...
research/object_detection/protos/box_predictor.proto
View file @
a4d9c3a0
...
@@ -118,6 +118,7 @@ message MaskRCNNBoxPredictor {
...
@@ -118,6 +118,7 @@ message MaskRCNNBoxPredictor {
// The number of convolutions applied to image_features in the mask prediction
// The number of convolutions applied to image_features in the mask prediction
// branch.
// branch.
optional
int32
mask_prediction_num_conv_layers
=
11
[
default
=
2
];
optional
int32
mask_prediction_num_conv_layers
=
11
[
default
=
2
];
optional
bool
masks_are_class_agnostic
=
12
[
default
=
false
];
}
}
message
RfcnBoxPredictor
{
message
RfcnBoxPredictor
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment