Unverified Commit e96f2d5c authored by Negin Raoof's avatar Negin Raoof Committed by GitHub
Browse files

Fix for roi_align export (#1988)

parent 2875315d
......@@ -18,8 +18,10 @@ def _register_custom_op():
nms_out = g.op('NonMaxSuppression', boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))), 1)
@parse_args('v', 'v', 'f', 'i', 'i', 'i')
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
@parse_args('v', 'v', 'f', 'i', 'i', 'i', 'i')
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
if(aligned):
raise RuntimeError('Unsupported: ONNX export of roi_align with aligned')
batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, g.op('Constant',
value_t=torch.tensor([0], dtype=torch.long))), 1), False)
rois = select(g, rois, 1, g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment