Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3324d816
Commit
3324d816
authored
Sep 02, 2021
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Sep 02, 2021
Browse files
Implement embedding based similarity mask head for DeepMAC.
PiperOrigin-RevId: 394512536
parent
0daae829
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
170 additions
and
22 deletions
+170
-22
research/object_detection/builders/losses_builder.py
research/object_detection/builders/losses_builder.py
+2
-1
research/object_detection/core/losses.py
research/object_detection/core/losses.py
+9
-2
research/object_detection/meta_architectures/deepmac_meta_arch.py
.../object_detection/meta_architectures/deepmac_meta_arch.py
+65
-7
research/object_detection/meta_architectures/deepmac_meta_arch_test.py
...ct_detection/meta_architectures/deepmac_meta_arch_test.py
+90
-12
research/object_detection/protos/losses.proto
research/object_detection/protos/losses.proto
+4
-0
No files found.
research/object_detection/builders/losses_builder.py
View file @
3324d816
...
@@ -263,7 +263,8 @@ def _build_classification_loss(loss_config):
...
@@ -263,7 +263,8 @@ def _build_classification_loss(loss_config):
elif
loss_type
==
'weighted_dice_classification_loss'
:
elif
loss_type
==
'weighted_dice_classification_loss'
:
config
=
loss_config
.
weighted_dice_classification_loss
config
=
loss_config
.
weighted_dice_classification_loss
return
losses
.
WeightedDiceClassificationLoss
(
return
losses
.
WeightedDiceClassificationLoss
(
squared_normalization
=
config
.
squared_normalization
)
squared_normalization
=
config
.
squared_normalization
,
is_prediction_probability
=
config
.
is_prediction_probability
)
else
:
else
:
raise
ValueError
(
'Empty loss config.'
)
raise
ValueError
(
'Empty loss config.'
)
research/object_detection/core/losses.py
View file @
3324d816
...
@@ -286,15 +286,19 @@ class WeightedDiceClassificationLoss(Loss):
...
@@ -286,15 +286,19 @@ class WeightedDiceClassificationLoss(Loss):
"""
"""
def
__init__
(
self
,
squared_normalization
):
def
__init__
(
self
,
squared_normalization
,
is_prediction_probability
=
False
):
"""Initializes the loss object.
"""Initializes the loss object.
Args:
Args:
squared_normalization: boolean, if set, we square the probabilities in the
squared_normalization: boolean, if set, we square the probabilities in the
denominator term used for normalization.
denominator term used for normalization.
is_prediction_probability: boolean, whether or not the input
prediction_tensor represents a probability. If false, it is
first converted to a probability by applying sigmoid.
"""
"""
self
.
_squared_normalization
=
squared_normalization
self
.
_squared_normalization
=
squared_normalization
self
.
is_prediction_probability
=
is_prediction_probability
super
(
WeightedDiceClassificationLoss
,
self
).
__init__
()
super
(
WeightedDiceClassificationLoss
,
self
).
__init__
()
def
_compute_loss
(
self
,
def
_compute_loss
(
self
,
...
@@ -332,7 +336,10 @@ class WeightedDiceClassificationLoss(Loss):
...
@@ -332,7 +336,10 @@ class WeightedDiceClassificationLoss(Loss):
tf
.
shape
(
prediction_tensor
)[
2
]),
tf
.
shape
(
prediction_tensor
)[
2
]),
[
1
,
1
,
-
1
])
[
1
,
1
,
-
1
])
prob_tensor
=
tf
.
nn
.
sigmoid
(
prediction_tensor
)
if
self
.
is_prediction_probability
:
prob_tensor
=
prediction_tensor
else
:
prob_tensor
=
tf
.
nn
.
sigmoid
(
prediction_tensor
)
if
self
.
_squared_normalization
:
if
self
.
_squared_normalization
:
prob_tensor
=
tf
.
pow
(
prob_tensor
,
2
)
prob_tensor
=
tf
.
pow
(
prob_tensor
,
2
)
...
...
research/object_detection/meta_architectures/deepmac_meta_arch.py
View file @
3324d816
...
@@ -36,7 +36,8 @@ class DeepMACParams(
...
@@ -36,7 +36,8 @@ class DeepMACParams(
'allowed_masked_classes_ids'
,
'mask_size'
,
'mask_num_subsamples'
,
'allowed_masked_classes_ids'
,
'mask_size'
,
'mask_num_subsamples'
,
'use_xy'
,
'network_type'
,
'use_instance_embedding'
,
'num_init_channels'
,
'use_xy'
,
'network_type'
,
'use_instance_embedding'
,
'num_init_channels'
,
'predict_full_resolution_masks'
,
'postprocess_crop_size'
,
'predict_full_resolution_masks'
,
'postprocess_crop_size'
,
'max_roi_jitter_ratio'
,
'roi_jitter_mode'
,
'box_consistency_loss_weight'
'max_roi_jitter_ratio'
,
'roi_jitter_mode'
,
'box_consistency_loss_weight'
,
])):
])):
"""Class holding the DeepMAC network configutration."""
"""Class holding the DeepMAC network configutration."""
...
@@ -125,6 +126,9 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
...
@@ -125,6 +126,9 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise
ValueError
(
'Mask size must be set.'
)
raise
ValueError
(
'Mask size must be set.'
)
return
FullyConnectedMaskHead
(
num_init_channels
,
mask_size
)
return
FullyConnectedMaskHead
(
num_init_channels
,
mask_size
)
elif
name
==
'embedding_distance_probability'
:
return
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
)
elif
name
.
startswith
(
'resnet'
):
elif
name
.
startswith
(
'resnet'
):
return
ResNetMaskNetwork
(
name
,
num_init_channels
)
return
ResNetMaskNetwork
(
name
,
num_init_channels
)
...
@@ -262,6 +266,25 @@ def fill_boxes(boxes, height, width):
...
@@ -262,6 +266,25 @@ def fill_boxes(boxes, height, width):
return
tf
.
cast
(
filled_boxes
,
tf
.
float32
)
return
tf
.
cast
(
filled_boxes
,
tf
.
float32
)
def
embedding_distance_to_probability
(
x
,
y
):
"""Compute probability based on pixel-wise embedding distance.
Args:
x: [num_instances, height, width, dimension] float tensor input.
y: [num_instances, height, width, dimension] or
[num_instances, 1, 1, dimension] float tensor input. When the height
and width dimensions are 1, TF will broadcast it.
Returns:
dist: [num_instances, height, width, 1] A float tensor returning
the per-pixel probability. Pixels whose embeddings are close in
euclidean distance get a probability of close to 1.
"""
diff
=
x
-
y
squared_dist
=
tf
.
reduce_sum
(
diff
*
diff
,
axis
=
3
,
keepdims
=
True
)
return
tf
.
exp
(
-
squared_dist
)
class
ResNetMaskNetwork
(
tf
.
keras
.
layers
.
Layer
):
class
ResNetMaskNetwork
(
tf
.
keras
.
layers
.
Layer
):
"""A small wrapper around ResNet blocks to predict masks."""
"""A small wrapper around ResNet blocks to predict masks."""
...
@@ -366,8 +389,18 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
...
@@ -366,8 +389,18 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
network_type
,
num_init_channels
,
mask_size
)
network_type
,
num_init_channels
,
mask_size
)
self
.
_use_instance_embedding
=
use_instance_embedding
self
.
_use_instance_embedding
=
use_instance_embedding
self
.
project_out
=
tf
.
keras
.
layers
.
Conv2D
(
self
.
_network_type
=
network_type
filters
=
1
,
kernel_size
=
1
,
activation
=
None
)
if
(
self
.
_use_instance_embedding
and
(
self
.
_network_type
==
'embedding_distance_probability'
)):
raise
ValueError
((
'Cannot feed instance embedding to mask head when '
'computing distance from instance embedding.'
))
if
network_type
==
'embedding_distance_probability'
:
self
.
project_out
=
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
x
)
else
:
self
.
project_out
=
tf
.
keras
.
layers
.
Conv2D
(
filters
=
1
,
kernel_size
=
1
,
activation
=
None
)
def
__call__
(
self
,
instance_embedding
,
pixel_embedding
,
training
):
def
__call__
(
self
,
instance_embedding
,
pixel_embedding
,
training
):
"""Returns mask logits given object center and spatial embeddings.
"""Returns mask logits given object center and spatial embeddings.
...
@@ -388,10 +421,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
...
@@ -388,10 +421,9 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
height
=
tf
.
shape
(
pixel_embedding
)[
1
]
height
=
tf
.
shape
(
pixel_embedding
)[
1
]
width
=
tf
.
shape
(
pixel_embedding
)[
2
]
width
=
tf
.
shape
(
pixel_embedding
)[
2
]
instance_embedding
=
instance_embedding
[:,
tf
.
newaxis
,
tf
.
newaxis
,
:]
instance_embedding
=
tf
.
tile
(
instance_embedding
,
[
1
,
height
,
width
,
1
])
if
self
.
_use_instance_embedding
:
if
self
.
_use_instance_embedding
:
instance_embedding
=
instance_embedding
[:,
tf
.
newaxis
,
tf
.
newaxis
,
:]
instance_embedding
=
tf
.
tile
(
instance_embedding
,
[
1
,
height
,
width
,
1
])
inputs
=
tf
.
concat
([
pixel_embedding
,
instance_embedding
],
axis
=
3
)
inputs
=
tf
.
concat
([
pixel_embedding
,
instance_embedding
],
axis
=
3
)
else
:
else
:
inputs
=
pixel_embedding
inputs
=
pixel_embedding
...
@@ -400,6 +432,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
...
@@ -400,6 +432,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
if
isinstance
(
out
,
list
):
if
isinstance
(
out
,
list
):
out
=
out
[
-
1
]
out
=
out
[
-
1
]
if
self
.
_network_type
==
'embedding_distance_probability'
:
instance_embedding
=
instance_embedding
[:,
tf
.
newaxis
,
tf
.
newaxis
,
:]
out
=
embedding_distance_to_probability
(
instance_embedding
,
out
)
if
out
.
shape
[
-
1
]
>
1
:
if
out
.
shape
[
-
1
]
>
1
:
out
=
self
.
project_out
(
out
)
out
=
self
.
project_out
(
out
)
...
@@ -466,6 +502,25 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -466,6 +502,25 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
if
self
.
_deepmac_params
.
mask_num_subsamples
>
0
:
if
self
.
_deepmac_params
.
mask_num_subsamples
>
0
:
raise
ValueError
(
'Subsampling masks is currently not supported.'
)
raise
ValueError
(
'Subsampling masks is currently not supported.'
)
if
self
.
_deepmac_params
.
network_type
==
'embedding_distance_probability'
:
if
self
.
_deepmac_params
.
use_xy
:
raise
ValueError
(
'Cannot use x/y coordinates when using embedding distance.'
)
pixel_embedding_dim
=
self
.
_deepmac_params
.
pixel_embedding_dim
dim
=
self
.
_deepmac_params
.
dim
if
dim
!=
pixel_embedding_dim
:
raise
ValueError
(
'When using embedding distance mask head, '
f
'pixel_embedding_dim(
{
pixel_embedding_dim
}
) '
f
'must be same as dim(
{
dim
}
).'
)
loss
=
self
.
_deepmac_params
.
classification_loss
if
((
not
isinstance
(
loss
,
losses
.
WeightedDiceClassificationLoss
))
or
(
not
loss
.
is_prediction_probability
)):
raise
ValueError
(
'Only dice loss with is_prediction_probability=true '
'is supported with embedding distance mask head.'
)
super
(
DeepMACMetaArch
,
self
).
__init__
(
super
(
DeepMACMetaArch
,
self
).
__init__
(
is_training
=
is_training
,
add_summaries
=
add_summaries
,
is_training
=
is_training
,
add_summaries
=
add_summaries
,
num_classes
=
num_classes
,
feature_extractor
=
feature_extractor
,
num_classes
=
num_classes
,
feature_extractor
=
feature_extractor
,
...
@@ -909,7 +964,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
...
@@ -909,7 +964,10 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits
=
crop_masks_within_boxes
(
mask_logits
=
crop_masks_within_boxes
(
mask_logits
,
boxes
,
self
.
_deepmac_params
.
postprocess_crop_size
)
mask_logits
,
boxes
,
self
.
_deepmac_params
.
postprocess_crop_size
)
masks_prob
=
tf
.
nn
.
sigmoid
(
mask_logits
)
if
self
.
_deepmac_params
.
network_type
==
'embedding_distance_probability'
:
masks_prob
=
mask_logits
else
:
masks_prob
=
tf
.
nn
.
sigmoid
(
mask_logits
)
return
masks_prob
return
masks_prob
...
...
research/object_detection/meta_architectures/deepmac_meta_arch_test.py
View file @
3324d816
...
@@ -61,7 +61,10 @@ class MockMaskNet(tf.keras.layers.Layer):
...
@@ -61,7 +61,10 @@ class MockMaskNet(tf.keras.layers.Layer):
def
build_meta_arch
(
predict_full_resolution_masks
=
False
,
use_dice_loss
=
False
,
def
build_meta_arch
(
predict_full_resolution_masks
=
False
,
use_dice_loss
=
False
,
mask_num_subsamples
=-
1
):
use_instance_embedding
=
True
,
mask_num_subsamples
=-
1
,
network_type
=
'hourglass10'
,
use_xy
=
True
,
pixel_embedding_dim
=
2
,
dice_loss_prediction_probability
=
False
):
"""Builds the DeepMAC meta architecture."""
"""Builds the DeepMAC meta architecture."""
feature_extractor
=
DummyFeatureExtractor
(
feature_extractor
=
DummyFeatureExtractor
(
...
@@ -84,7 +87,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
...
@@ -84,7 +87,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
use_labeled_classes
=
False
)
use_labeled_classes
=
False
)
if
use_dice_loss
:
if
use_dice_loss
:
classification_loss
=
losses
.
WeightedDiceClassificationLoss
(
False
)
classification_loss
=
losses
.
WeightedDiceClassificationLoss
(
squared_normalization
=
False
,
is_prediction_probability
=
dice_loss_prediction_probability
)
else
:
else
:
classification_loss
=
losses
.
WeightedSigmoidClassificationLoss
()
classification_loss
=
losses
.
WeightedSigmoidClassificationLoss
()
...
@@ -92,13 +97,13 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
...
@@ -92,13 +97,13 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
classification_loss
=
classification_loss
,
classification_loss
=
classification_loss
,
dim
=
8
,
dim
=
8
,
task_loss_weight
=
1.0
,
task_loss_weight
=
1.0
,
pixel_embedding_dim
=
2
,
pixel_embedding_dim
=
pixel_embedding_dim
,
allowed_masked_classes_ids
=
[],
allowed_masked_classes_ids
=
[],
mask_size
=
16
,
mask_size
=
16
,
mask_num_subsamples
=
mask_num_subsamples
,
mask_num_subsamples
=
mask_num_subsamples
,
use_xy
=
True
,
use_xy
=
use_xy
,
network_type
=
'hourglass10'
,
network_type
=
network_type
,
use_instance_embedding
=
True
,
use_instance_embedding
=
use_instance_embedding
,
num_init_channels
=
8
,
num_init_channels
=
8
,
predict_full_resolution_masks
=
predict_full_resolution_masks
,
predict_full_resolution_masks
=
predict_full_resolution_masks
,
postprocess_crop_size
=
128
,
postprocess_crop_size
=
128
,
...
@@ -125,7 +130,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
...
@@ -125,7 +130,7 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False,
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
DeepMACUtilsTest
(
tf
.
test
.
TestCase
):
class
DeepMACUtilsTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_subsample_trivial
(
self
):
def
test_subsample_trivial
(
self
):
"""Test subsampling masks."""
"""Test subsampling masks."""
...
@@ -169,12 +174,22 @@ class DeepMACUtilsTest(tf.test.TestCase):
...
@@ -169,12 +174,22 @@ class DeepMACUtilsTest(tf.test.TestCase):
features
,
boxes
,
32
)
features
,
boxes
,
32
)
self
.
assertEqual
(
output
.
shape
,
(
5
,
32
,
32
,
7
))
self
.
assertEqual
(
output
.
shape
,
(
5
,
32
,
32
,
7
))
def
test_embedding_distance_prob_shape
(
self
):
dist
=
deepmac_meta_arch
.
embedding_distance_to_probability
(
tf
.
ones
((
4
,
32
,
32
,
8
)),
tf
.
zeros
((
4
,
32
,
32
,
8
)))
self
.
assertEqual
(
dist
.
shape
,
(
4
,
32
,
32
,
1
))
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
@
parameterized
.
parameters
([
1e-20
,
1e20
])
class
DeepMACMetaArchTest
(
tf
.
test
.
TestCase
):
def
test_embedding_distance_prob_value
(
self
,
value
):
dist
=
deepmac_meta_arch
.
embedding_distance_to_probability
(
tf
.
zeros
((
1
,
1
,
1
,
8
)),
value
+
tf
.
zeros
((
1
,
1
,
1
,
8
))).
numpy
()
max_float
=
np
.
finfo
(
dist
.
dtype
).
max
self
.
assertLess
(
dist
.
max
(),
max_float
)
self
.
assertGreater
(
dist
.
max
(),
-
max_float
)
def
setUp
(
self
):
# pylint:disable=g-missing-super-call
self
.
model
=
build_meta_arch
()
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
DeepMACMaskHeadTest
(
tf
.
test
.
TestCase
):
def
test_mask_network
(
self
):
def
test_mask_network
(
self
):
net
=
deepmac_meta_arch
.
MaskHeadNetwork
(
'hourglass10'
,
8
)
net
=
deepmac_meta_arch
.
MaskHeadNetwork
(
'hourglass10'
,
8
)
...
@@ -203,6 +218,38 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -203,6 +218,38 @@ class DeepMACMetaArchTest(tf.test.TestCase):
out
=
call_func
(
tf
.
zeros
((
2
,
4
)),
tf
.
zeros
((
2
,
32
,
32
,
16
)),
training
=
True
)
out
=
call_func
(
tf
.
zeros
((
2
,
4
)),
tf
.
zeros
((
2
,
32
,
32
,
16
)),
training
=
True
)
self
.
assertEqual
(
out
.
shape
,
(
2
,
32
,
32
))
self
.
assertEqual
(
out
.
shape
,
(
2
,
32
,
32
))
def
test_mask_network_embedding_distance_zero_dist
(
self
):
net
=
deepmac_meta_arch
.
MaskHeadNetwork
(
'embedding_distance_probability'
,
num_init_channels
=
8
,
use_instance_embedding
=
False
)
call_func
=
tf
.
function
(
net
.
__call__
)
out
=
call_func
(
tf
.
zeros
((
2
,
7
)),
tf
.
zeros
((
2
,
32
,
32
,
7
)),
training
=
True
)
self
.
assertEqual
(
out
.
shape
,
(
2
,
32
,
32
))
self
.
assertAllGreater
(
out
.
numpy
(),
-
np
.
inf
)
self
.
assertAllLess
(
out
.
numpy
(),
np
.
inf
)
def
test_mask_network_embedding_distance_small_dist
(
self
):
net
=
deepmac_meta_arch
.
MaskHeadNetwork
(
'embedding_distance_probability'
,
num_init_channels
=-
1
,
use_instance_embedding
=
False
)
call_func
=
tf
.
function
(
net
.
__call__
)
out
=
call_func
(
1e6
+
tf
.
zeros
((
2
,
7
)),
tf
.
zeros
((
2
,
32
,
32
,
7
)),
training
=
True
)
self
.
assertEqual
(
out
.
shape
,
(
2
,
32
,
32
))
self
.
assertAllGreater
(
out
.
numpy
(),
-
np
.
inf
)
self
.
assertAllLess
(
out
.
numpy
(),
np
.
inf
)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
DeepMACMetaArchTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
# pylint:disable=g-missing-super-call
self
.
model
=
build_meta_arch
()
def
test_get_mask_head_input
(
self
):
def
test_get_mask_head_input
(
self
):
boxes
=
tf
.
constant
([[
0.
,
0.
,
0.25
,
0.25
],
[
0.75
,
0.75
,
1.0
,
1.0
]],
boxes
=
tf
.
constant
([[
0.
,
0.
,
0.25
,
0.25
],
[
0.75
,
0.75
,
1.0
,
1.0
]],
...
@@ -349,6 +396,37 @@ class DeepMACMetaArchTest(tf.test.TestCase):
...
@@ -349,6 +396,37 @@ class DeepMACMetaArchTest(tf.test.TestCase):
prob
=
tf
.
nn
.
sigmoid
(
0.9
).
numpy
()
prob
=
tf
.
nn
.
sigmoid
(
0.9
).
numpy
()
self
.
assertAllClose
(
masks
,
prob
*
np
.
ones
((
2
,
3
,
16
,
16
)))
self
.
assertAllClose
(
masks
,
prob
*
np
.
ones
((
2
,
3
,
16
,
16
)))
def
test_postprocess_emb_dist
(
self
):
model
=
build_meta_arch
(
network_type
=
'embedding_distance_probability'
,
use_instance_embedding
=
False
,
use_xy
=
False
,
pixel_embedding_dim
=
8
,
use_dice_loss
=
True
,
dice_loss_prediction_probability
=
True
)
boxes
=
np
.
zeros
((
2
,
3
,
4
),
dtype
=
np
.
float32
)
boxes
[:,
:,
[
0
,
2
]]
=
0.0
boxes
[:,
:,
[
1
,
3
]]
=
8.0
boxes
=
tf
.
constant
(
boxes
)
masks
=
model
.
_postprocess_masks
(
boxes
,
tf
.
zeros
((
2
,
32
,
32
,
2
)),
tf
.
zeros
((
2
,
32
,
32
,
2
)))
self
.
assertEqual
(
masks
.
shape
,
(
2
,
3
,
16
,
16
))
def
test_postprocess_emb_dist_fullres
(
self
):
model
=
build_meta_arch
(
network_type
=
'embedding_distance_probability'
,
predict_full_resolution_masks
=
True
,
use_instance_embedding
=
False
,
pixel_embedding_dim
=
8
,
use_xy
=
False
,
use_dice_loss
=
True
,
dice_loss_prediction_probability
=
True
)
boxes
=
np
.
zeros
((
2
,
3
,
4
),
dtype
=
np
.
float32
)
boxes
=
tf
.
constant
(
boxes
)
masks
=
model
.
_postprocess_masks
(
boxes
,
tf
.
zeros
((
2
,
32
,
32
,
2
)),
tf
.
zeros
((
2
,
32
,
32
,
2
)))
self
.
assertEqual
(
masks
.
shape
,
(
2
,
3
,
128
,
128
))
def
test_postprocess_no_crop_resize_shape
(
self
):
def
test_postprocess_no_crop_resize_shape
(
self
):
model
=
build_meta_arch
(
predict_full_resolution_masks
=
True
)
model
=
build_meta_arch
(
predict_full_resolution_masks
=
True
)
...
@@ -494,7 +572,7 @@ class FullyConnectedMaskHeadTest(tf.test.TestCase):
...
@@ -494,7 +572,7 @@ class FullyConnectedMaskHeadTest(tf.test.TestCase):
class
ResNetMaskHeadTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
ResNetMaskHeadTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
parameters
([
'resnet4'
,
'resnet8'
,
'resnet20'
])
@
parameterized
.
parameters
([
'resnet4'
,
'resnet8'
,
'resnet20'
])
def
test_
pass
(
self
,
name
):
def
test_
forward
(
self
,
name
):
net
=
deepmac_meta_arch
.
ResNetMaskNetwork
(
name
,
8
)
net
=
deepmac_meta_arch
.
ResNetMaskNetwork
(
name
,
8
)
out
=
net
(
tf
.
zeros
((
3
,
32
,
32
,
16
)))
out
=
net
(
tf
.
zeros
((
3
,
32
,
32
,
16
)))
self
.
assertEqual
(
out
.
shape
[:
3
],
(
3
,
32
,
32
))
self
.
assertEqual
(
out
.
shape
[:
3
],
(
3
,
32
,
32
))
...
...
research/object_detection/protos/losses.proto
View file @
3324d816
...
@@ -231,6 +231,10 @@ message WeightedDiceClassificationLoss {
...
@@ -231,6 +231,10 @@ message WeightedDiceClassificationLoss {
// If set, we square the probabilities in the denominator term used for
// If set, we square the probabilities in the denominator term used for
// normalization.
// normalization.
optional
bool
squared_normalization
=
1
[
default
=
false
];
optional
bool
squared_normalization
=
1
[
default
=
false
];
// Whether or not the input prediction to the loss function is a
// probability. If not, the input is to be interpreted as logit
optional
bool
is_prediction_probability
=
2
[
default
=
false
];
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment