"vscode:/vscode.git/clone" did not exist on "70e4eb567fcad81c57598ab9ee6f81b4136ecca5"
Commit 3564e7ca authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix

parent 7b165eb4
......@@ -1921,6 +1921,46 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
np.testing.assert_array_almost_equal(
expected_seg_target, segmentation_target)
def test_assign_detr(self):
def graph_fn(anchor_means, groundtruth_box_corners):
similarity_calc = region_similarity_calculator.DETRSimilarity()
matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
unmatched_threshold=0.5)
box_coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
target_assigner = targetassigner.TargetAssigner(
similarity_calc, matcher, box_coder)
anchors_boxlist = box_list.BoxList(anchor_means)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = target_assigner.assign(
anchors_boxlist, groundtruth_boxlist, unmatched_class_label=None)
(cls_targets, cls_weights, reg_targets, reg_weights, _) = result
return (cls_targets, cls_weights, reg_targets, reg_weights)
anchor_means = np.array([[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 1.0, 0.8],
[0, 0.5, .5, 1.0]], dtype=np.float32)
groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9]],
dtype=np.float32)
exp_cls_targets = [[1], [1], [0]]
exp_cls_weights = [[1], [1], [1]]
exp_reg_targets = [[0, 0, 0, 0],
[0, 0, -1, 1],
[0, 0, 0, 0]]
exp_reg_weights = [1, 1, 0]
(cls_targets_out,
cls_weights_out, reg_targets_out, reg_weights_out) = self.execute(
graph_fn, [anchor_means, groundtruth_box_corners])
self.assertAllClose(cls_targets_out, exp_cls_targets)
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_targets_out, exp_reg_targets)
self.assertAllClose(reg_weights_out, exp_reg_weights)
self.assertEqual(cls_targets_out.dtype, np.float32)
self.assertEqual(cls_weights_out.dtype, np.float32)
self.assertEqual(reg_targets_out.dtype, np.float32)
self.assertEqual(reg_weights_out.dtype, np.float32)
class CenterNetDensePoseTargetAssignerTest(test_case.TestCase):
......
......@@ -604,7 +604,7 @@ def train_loop(
return strategy.reduce(tf.distribute.ReduceOp.SUM,
per_replica_losses, axis=None)
@tf.function
#@tf.function
def _dist_train_step(data_iterator):
"""A distributed train step."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment