Unverified Commit 747f406a authored by Ksenija Stanojevic's avatar Ksenija Stanojevic Committed by GitHub
Browse files

fix bug (#2312)

parent 883f1fb0
......@@ -446,25 +446,9 @@ class ONNXExporterTester(unittest.TestCase):
assert torch.all(out2[1].eq(out_trace2[1]))
def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
super(KeyPointRCNN, self).__init__()
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
pretrained=True, min_size=200, max_size=300)
def forward(self, images):
output = self.model(images)
# TODO: The keypoints_scores require the use of Argmax that is updated in ONNX.
# For now we are testing all the output of KeypointRCNN except keypoints_scores.
# Enable When Argmax is updated in ONNX Runtime.
return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints']
images, test_images = self.get_test_images()
# TODO:
# Enable test for dummy_image (no detection) once issue is
# _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed
dummy_images = [torch.ones(3, 100, 100) * 0.3]
model = KeyPointRCNN()
model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,), (dummy_images,)],
......@@ -472,8 +456,7 @@ class ONNXExporterTester(unittest.TestCase):
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
......
......@@ -196,8 +196,12 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
xy_preds_i_2.to(dtype=torch.float32)], 0)
# TODO: simplify when indexing without rank will be supported by ONNX
base = num_keypoints * num_keypoints + num_keypoints + 1
ind = torch.arange(num_keypoints)
ind = ind.to(dtype=torch.int64) * base
end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \
.index_select(2, x_int.to(dtype=torch.int64))[:num_keypoints, 0, 0]
.index_select(2, x_int.to(dtype=torch.int64)).view(-1).index_select(0, ind.to(dtype=torch.int64))
return xy_preds_i, end_scores_i
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment