import sys
import warnings

import torch
from torch.onnx import symbolic_opset11 as opset11
from torch.onnx.symbolic_helper import parse_args
# Lowest opset that supports all of the ops registered below
# (NonMaxSuppression, RoiAlign, MaxRoiPool).
_ONNX_OPSET_VERSION_11 = 11
# Opset 16 RoiAlign adds the `coordinate_transformation_mode` attribute,
# which is needed to export RoIAlign with aligned=True (see roi_align_opset16).
_ONNX_OPSET_VERSION_16 = 16
# Default opset used when exporting torchvision models.
BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11


@parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
    """ONNX symbolic for ``torchvision::nms`` using NonMaxSuppression.

    NonMaxSuppression works on batched, per-class inputs, so the single-image
    boxes/scores are expanded with leading dimensions before the call and the
    selected indices are reduced back to a 1-D tensor of box indices.
    """
    # Add the leading batch (and, for scores, class) dimensions that
    # NonMaxSuppression expects.
    batched_boxes = opset11.unsqueeze(g, boxes, 0)
    batched_scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)

    # No cap on the number of detections: use the largest representable count.
    max_detections = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
    iou_thresh_const = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))

    # Cast boxes and scores to float32 in case they are float64 inputs.
    nms_result = g.op(
        "NonMaxSuppression",
        g.op("Cast", batched_boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
        g.op("Cast", batched_scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
        max_detections,
        iou_thresh_const,
    )
    # NonMaxSuppression rows are (batch_index, class_index, box_index);
    # keep only the box-index column and drop the singleton dimension.
    box_index_col = g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))
    return opset11.squeeze(g, opset11.select(g, nms_result, 1, box_index_col), 1)


def _process_batch_indices_for_roi_align(g, rois):
    """Extract the first column of ``rois`` (per-roi batch index) as int64."""
    col0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
    batch_idx = opset11.squeeze(g, opset11.select(g, rois, 1, col0), 1)
    # ONNX RoiAlign requires the batch_indices input to be int64.
    return g.op("Cast", batch_idx, to_i=torch.onnx.TensorProtoDataType.INT64)


def _process_rois_for_roi_align(g, rois):
    """Keep only the box-coordinate columns (1..4) of ``rois``, dropping
    the leading batch-index column."""
    coord_cols = g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))
    return opset11.select(g, rois, 1, coord_cols)


def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
    if sampling_ratio < 0:
        warnings.warn(
            "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
            "The model will be exported with a sampling_ratio of 0."
49
        )
Justin Chu's avatar
Justin Chu committed
50
51
52
        sampling_ratio = 0
    return sampling_ratio

53

Justin Chu's avatar
Justin Chu committed
54
55
56
57
58
59
60
61
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """ONNX symbolic for ``torchvision::roi_align`` at opset 11.

    Opset 11 RoiAlign has no ``coordinate_transformation_mode`` attribute, so
    ``aligned=True`` cannot be represented faithfully: a warning is emitted
    and the op is exported as if ``aligned=False``. Export with opset >= 16
    (see ``roi_align_opset16``) for correct aligned behavior.
    """
    batch_indices = _process_batch_indices_for_roi_align(g, rois)
    # Strip the batch-index column; RoiAlign takes coordinates and indices
    # as separate inputs.
    rois = _process_rois_for_roi_align(g, rois)
    if aligned:
        warnings.warn(
            "ROIAlign with aligned=True is only supported in opset >= 16. "
            "Please export with opset 16 or higher, or use aligned=False."
        )
    sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
    return g.op(
        "RoiAlign",
        input,
        rois,
        batch_indices,
        spatial_scale_f=spatial_scale,
        output_height_i=pooled_height,
        output_width_i=pooled_width,
        sampling_ratio_i=sampling_ratio,
    )


@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """ONNX symbolic for ``torchvision::roi_align`` at opset >= 16.

    Opset 16 RoiAlign exposes ``coordinate_transformation_mode``, so both the
    aligned and non-aligned variants can be exported faithfully.
    """
    batch_indices = _process_batch_indices_for_roi_align(g, rois)
    box_coords = _process_rois_for_roi_align(g, rois)
    # aligned=True maps to half-pixel coordinate alignment.
    if aligned:
        transform_mode = "half_pixel"
    else:
        transform_mode = "output_half_pixel"
    effective_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
    attrs = {
        "coordinate_transformation_mode_s": transform_mode,
        "spatial_scale_f": spatial_scale,
        "output_height_i": pooled_height,
        "output_width_i": pooled_width,
        "sampling_ratio_i": effective_ratio,
    }
    return g.op("RoiAlign", input, box_coords, batch_indices, **attrs)


@parse_args("v", "v", "f", "i", "i")
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
    """ONNX symbolic for ``torchvision::roi_pool`` via MaxRoiPool.

    The torchvision op has a second output with no ONNX counterpart, so
    ``None`` is returned in its place.
    """
    # Local renamed from the original `roi_pool` to avoid shadowing this
    # function's own name.
    pooled = g.op(
        "MaxRoiPool",
        input,
        rois,
        pooled_shape_i=(pooled_height, pooled_width),
        spatial_scale_f=spatial_scale,
    )
    return pooled, None


def _register_custom_op():
    """Register torchvision's ONNX symbolic functions with torch.onnx."""
    # (op name, symbolic function, minimum opset) — registration order
    # matches the original explicit call sequence.
    registrations = (
        ("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11),
        ("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11),
        ("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16),
        ("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11),
    )
    for op_name, symbolic_fn, opset in registrations:
        torch.onnx.register_custom_op_symbolic(op_name, symbolic_fn, opset)