You need to sign in or sign up before continuing.
Unverified Commit de52437c authored by Ksenija Stanojevic's avatar Ksenija Stanojevic Committed by GitHub
Browse files

enable detection, no-detection test cases (#2272)

parent 37a0d8d6
...@@ -463,22 +463,22 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -463,22 +463,22 @@ class ONNXExporterTester(unittest.TestCase):
# TODO: # TODO:
# Enable test for dummy_image (no detection) once issue is # Enable test for dummy_image (no detection) once issue is
# _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed # _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed
# dummy_images = [torch.ones(3, 100, 100) * 0.3] dummy_images = [torch.ones(3, 100, 100) * 0.3]
model = KeyPointRCNN() model = KeyPointRCNN()
model.eval() model.eval()
model(images) model(images)
self.run_model(model, [(images,), (test_images,)], self.run_model(model, [(images,), (test_images,), (dummy_images,)],
input_names=["images_tensors"], input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]}, dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed # TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images # Test exported model for an image with no detections on other images
# self.run_model(model, [(dummy_images,), (test_images,)], self.run_model(model, [(dummy_images,), (test_images,)],
# input_names=["images_tensors"], input_names=["images_tensors"],
# output_names=["outputs1", "outputs2", "outputs3", "outputs4"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
# dynamic_axes={"images_tensors": [0, 1, 2, 3]}, dynamic_axes={"images_tensors": [0, 1, 2, 3]},
# tolerate_small_mismatch=True) tolerate_small_mismatch=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -251,10 +251,9 @@ class KeypointRCNNPredictor(nn.Module): ...@@ -251,10 +251,9 @@ class KeypointRCNNPredictor(nn.Module):
def forward(self, x): def forward(self, x):
x = self.kps_score_lowres(x) x = self.kps_score_lowres(x)
x = torch.nn.functional.interpolate( return torch.nn.functional.interpolate(
x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
) )
return x
model_urls = { model_urls = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment