Commit a3f34adb authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Support dot-product in consistency metric.

PiperOrigin-RevId: 472760525
parent 18044f42
......@@ -449,15 +449,24 @@ def gaussian_pixel_similarity(a, b, theta):
return similarity
def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
"""Dilated cross pixel similarity as defined in [1].
def dotprod_pixel_similarity(a, b):
return tf.reduce_sum(a * b, axis=-1)
[1]: https://arxiv.org/abs/2012.02310
def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0,
method='gaussian'):
"""Dilated cross pixel similarity.
method supports 2 values
- 'gaussian' from https://arxiv.org/abs/2012.02310
- 'dotprod' computes the dot product between feature vector for similarity.
This assumes that the features are normalized.
Args:
feature_map: A float tensor of shape [batch_size, height, width, channels]
dilation: int, the dilation factor.
theta: The denominator while taking difference inside the gaussian.
method: str, either 'gaussian' or 'dotprod'.
Returns:
dilated_similarity: A tensor of shape [8, batch_size, height, width]
......@@ -465,7 +474,12 @@ def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
neighbors = generate_2d_neighbors(feature_map, dilation)
feature_map = feature_map[tf.newaxis]
return gaussian_pixel_similarity(feature_map, neighbors, theta=theta)
if method == 'gaussian':
return gaussian_pixel_similarity(feature_map, neighbors, theta=theta)
elif method == 'dotprod':
return dotprod_pixel_similarity(feature_map, neighbors)
else:
raise ValueError('Unknown method for pixel sim %s' % method)
def dilated_cross_same_mask_label(instance_masks, dilation=2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment