"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "45d5af2416f53940e48100754bdfbb6360c4e586"
Commit 7723b206 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix pr

parent 5f71a455
......@@ -1923,38 +1923,41 @@ class CenterNetMaskTargetAssignerTest(test_case.TestCase):
expected_seg_target, segmentation_target)
def test_assign_detr(self):
def graph_fn(anchor_means, groundtruth_box_corners):
def graph_fn(anchor_means, groundtruth_box_corners,
groundtruth_labels, predicted_labels):
similarity_calc = region_similarity_calculator.DETRSimilarity()
matcher = hungarian_matcher.HungarianBipartiteMatcher()
box_coder = None
box_coder = box
target_assigner = targetassigner.TargetAssigner(
similarity_calc, matcher, box_coder)
anchors_boxlist = box_list.BoxList(anchor_means)
groundtruth_boxlist = box_list.BoxList(groundtruth_box_corners)
result = target_assigner.assign(
anchors_boxlist, groundtruth_boxlist, unmatched_class_label=None)
anchors_boxlist, groundtruth_boxlist, unmatched_class_label=None,
groundtruth_labels=groundtruth_labels, class_predictions=predicted_labels)
(cls_targets, cls_weights, reg_targets, reg_weights, _) = result
return (cls_targets, cls_weights, reg_targets, reg_weights)
anchor_means = np.array([[0.0, 0.0, 0.2, 0.2],
anchor_means = np.array([[0.0, 0.0, 0.4, 0.2],
[0.5, 0.5, 1.0, 0.8],
[0, 0.5, .5, 1.0]], dtype=np.float32)
[0.9, 0.5, 0.1, 1.0]], dtype=np.float32)
groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9]],
dtype=np.float32)
predicted_labels = np.array([[7, 3], [2, 9], [1, 5]])
groundtruth = np.array([[0, 1], [2, 9], [1, 5]])
predicted_labels = np.array([[-3, 3], [2, 9], [5, 1]])
groundtruth_labels = np.array([[0, 1], [0, 1]])
exp_cls_targets = [[1], [1], [0]]
exp_cls_weights = [[1], [1], [1]]
exp_reg_targets = [[0, 0, 0, 0],
[0, 0, -1, 1],
exp_reg_targets = [[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.9, 0.9],
[0, 0, 0, 0]]
exp_reg_weights = [1, 1, 0]
(cls_targets_out,
cls_weights_out, reg_targets_out, reg_weights_out) = self.execute(
graph_fn, [anchor_means, groundtruth_box_corners])
graph_fn, [anchor_means, groundtruth_box_corners,
groundtruth_labels, predicted_labels])
self.assertAllClose(cls_targets_out, exp_cls_targets)
self.assertAllClose(cls_weights_out, exp_cls_weights)
self.assertAllClose(reg_targets_out, exp_reg_targets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment