".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "fce54fd121d69b06c10463c036dc6b0e6f9845e1"
Commit b8e3e969 authored by Lara Haidar's avatar Lara Haidar Committed by Francisco Massa
Browse files

Changes to Enable KeypointRCNN ONNX Export (#1593)

* code changes to enable onnx export for keypoint rcnn

* add import

* fix copy paste error
parent 8aec85de
...@@ -338,6 +338,43 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -338,6 +338,43 @@ class ONNXExporterTester(unittest.TestCase):
model(images) model(images)
self.run_model(model, [(images,), (test_images,)]) self.run_model(model, [(images,), (test_images,)])
# Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
# (since jit_trace witll call _heatmaps_to_keypoints).
# @unittest.skip("Disable test until Resize bug fixed in ORT")
def test_heatmaps_to_keypoints(self):
# disable profiling
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
maps = torch.rand(10, 1, 26, 26)
rois = torch.rand(10, 4)
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints
out = heatmaps_to_keypoints(maps, rois)
jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
out_trace = jit_trace(maps, rois)
assert torch.all(out[0].eq(out_trace[0]))
assert torch.all(out[1].eq(out_trace[1]))
maps2 = torch.rand(20, 2, 21, 21)
rois2 = torch.rand(20, 4)
from torchvision.models.detection.roi_heads import heatmaps_to_keypoints
out2 = heatmaps_to_keypoints(maps2, rois2)
out_trace2 = jit_trace(maps2, rois2)
assert torch.all(out2[0].eq(out_trace2[0]))
assert torch.all(out2[1].eq(out_trace2[1]))
@unittest.skip("Disable test until Argmax is updated in ONNX")
def test_keypoint_rcnn(self):
images, test_images = self.get_test_images()
model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(test_images)
self.run_model(model, [(images,), (test_images,)])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
from __future__ import division
import torch import torch
import torchvision import torchvision
...@@ -164,6 +165,58 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size): ...@@ -164,6 +165,58 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
return heatmaps, valid return heatmaps, valid
def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
                                widths_i, heights_i, offset_x_i, offset_y_i):
    """ONNX-traceable per-RoI body of heatmaps_to_keypoints.

    Upsamples one RoI's heatmaps, takes each keypoint's argmax location, and
    maps it back to image coordinates — using only ops the ONNX exporter can
    handle (no data-dependent advanced indexing; see the TODO below).

    maps: full heatmap batch; only maps.size(1) (#keypoints) is read here.
    maps_i: this RoI's heatmaps  # assumes shape (num_keypoints, h, w) — TODO confirm
    roi_map_width / roi_map_height: scalar tensors — target resample size.
    widths_i / heights_i: this RoI's width/height in image coordinates.
    offset_x_i / offset_y_i: this RoI's top-left corner in image coordinates.
    Returns (xy_preds_i, end_scores_i): a (3, num_keypoints) tensor of
    (x, y, 1) rows and a 1-D tensor of scores.
    """
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
    # Scale factors from resampled-heatmap pixels back to image coordinates.
    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height
    # Upsample this RoI's heatmaps to the RoI's own (ceil'd) size.
    roi_map = torch.nn.functional.interpolate(
        maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0]
    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    # Flatten each keypoint's map, take its argmax, then recover 2-D coords.
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
    x_int = (pos % w)
    y_int = ((pos - x_int) / w)  # pos - x_int is an exact multiple of w (row index)
    # +0.5 centers the prediction on the argmax pixel before rescaling.
    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * \
        width_correction.to(dtype=torch.float32)
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * \
        height_correction.to(dtype=torch.float32)
    # Rows of the output: image-space x, image-space y, and a constant 1.
    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones((xy_preds_i_1.shape), dtype=torch.float32)
    xy_preds_i = torch.stack([xy_preds_i_0.to(dtype=torch.float32),
                              xy_preds_i_1.to(dtype=torch.float32),
                              xy_preds_i_2.to(dtype=torch.float32)], 0)
    # TODO: simplify when indexing without rank will be supported by ONNX
    # NOTE(review): the chained index_select then [:num_keypoints, 0, 0] yields
    # roi_map[k, y_int[0], x_int[0]] for every k, not the per-keypoint
    # roi_map[k, y_int[k], x_int[k]] — verify against the eager (non-traced)
    # branch of heatmaps_to_keypoints before relying on these scores.
    end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \
        .index_select(2, x_int.to(dtype=torch.int64))[:num_keypoints, 0, 0]
    return xy_preds_i, end_scores_i
@torch.jit.script
def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil,
                                     widths, heights, offset_x, offset_y, num_keypoints):
    """Scripted per-RoI loop used by heatmaps_to_keypoints under ONNX tracing.

    Calls _onnx_heatmaps_to_keypoints once per RoI and accumulates the
    results by concatenating onto initially-empty tensors — a pattern that
    stays expressible in TorchScript/ONNX without preallocating by a
    batch size unknown at trace time.

    Returns (xy_preds, end_scores) with leading dimension rois.size(0):
    xy_preds is (num_rois, 3, num_keypoints), end_scores (num_rois, num_keypoints).
    """
    # Start from (0, ...) tensors; each iteration appends one RoI's result.
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(maps, maps[i],
                                                               widths_ceil[i], heights_ceil[i],
                                                               widths[i], heights[i],
                                                               offset_x[i], offset_y[i])
        # Explicit .to(float32) keeps concat dtypes consistent in the trace.
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32),
                              xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat((end_scores.to(dtype=torch.float32),
                                end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0)
    return xy_preds, end_scores
def heatmaps_to_keypoints(maps, rois): def heatmaps_to_keypoints(maps, rois):
"""Extract predicted keypoint locations from heatmaps. Output has shape """Extract predicted keypoint locations from heatmaps. Output has shape
(#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
...@@ -185,6 +238,14 @@ def heatmaps_to_keypoints(maps, rois): ...@@ -185,6 +238,14 @@ def heatmaps_to_keypoints(maps, rois):
heights_ceil = heights.ceil() heights_ceil = heights.ceil()
num_keypoints = maps.shape[1] num_keypoints = maps.shape[1]
if torchvision._is_tracing():
xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(maps, rois,
widths_ceil, heights_ceil, widths, heights,
offset_x, offset_y,
torch.scalar_tensor(num_keypoints, dtype=torch.int64))
return xy_preds.permute(0, 2, 1), end_scores
xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device) xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device) end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
for i in range(len(rois)): for i in range(len(rois)):
...@@ -244,7 +305,13 @@ def keypointrcnn_inference(x, boxes): ...@@ -244,7 +305,13 @@ def keypointrcnn_inference(x, boxes):
kp_probs = [] kp_probs = []
kp_scores = [] kp_scores = []
boxes_per_image = [len(box) for box in boxes] boxes_per_image = [box.size(0) for box in boxes]
if len(boxes_per_image) == 1:
# TODO : remove when dynamic split supported in ONNX
kp_prob, scores = heatmaps_to_keypoints(x, boxes[0])
return [kp_prob], [scores]
x2 = x.split(boxes_per_image, dim=0) x2 = x.split(boxes_per_image, dim=0)
for xx, bb in zip(x2, boxes): for xx, bb in zip(x2, boxes):
......
...@@ -154,8 +154,13 @@ def resize_keypoints(keypoints, original_size, new_size): ...@@ -154,8 +154,13 @@ def resize_keypoints(keypoints, original_size, new_size):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)) ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size))
ratio_h, ratio_w = ratios ratio_h, ratio_w = ratios
resized_data = keypoints.clone() resized_data = keypoints.clone()
resized_data[..., 0] *= ratio_w if torch._C._get_tracing_state():
resized_data[..., 1] *= ratio_h resized_data_0 = resized_data[:, :, 0] * ratio_w
resized_data_1 = resized_data[:, :, 1] * ratio_h
resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
else:
resized_data[..., 0] *= ratio_w
resized_data[..., 1] *= ratio_h
return resized_data return resized_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment