DetectionOutput.py 3.42 KB
Newer Older
yaoht's avatar
yaoht committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Tencent is pleased to support the open source community by making TNN available.
#
# Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import onnx
from typing import *
from onnx import helper
from typing import *
import ctypes
import src.c2oObject as Node


def create_attribuates(layer) -> Dict:
    """Collect the DetectionOutput settings of *layer* into a flat dict.

    Args:
        layer: Object exposing a ``detection_output_param`` attribute
            (presumably a parsed Caffe layer message -- confirm against
            the caller).

    Returns:
        A dict keyed by attribute name. The boolean-like fields
        ``share_location``, ``variance_encoded_in_target`` and
        ``visualize`` are stored as the integers ``1``/``0``; every other
        field is copied through unchanged. The three NMS values are read
        from the nested ``nms_param`` object.
    """
    param = layer.detection_output_param
    nms = param.nms_param  # nested NonMaximumSuppressionParameter

    def as_flag(value) -> int:
        # Truthy/falsy values are stored as a plain 1/0 integer.
        return 1 if value else 0

    # TODO: SaveOutputParameter (param.save_output_param: output_directory,
    # output_name_prefix, output_format, label_map_file, name_size_file,
    # num_test_image) is not exported yet.
    return {
        'num_classes': param.num_classes,
        'share_location': as_flag(param.share_location),
        'background_label_id': param.background_label_id,
        'nms_threshold': nms.nms_threshold,
        'top_k': nms.top_k,
        'eta': nms.eta,
        'code_type': param.code_type,
        'variance_encoded_in_target': as_flag(param.variance_encoded_in_target),
        'keep_top_k': param.keep_top_k,
        'confidence_threshold': param.confidence_threshold,
        'visualize': as_flag(param.visualize),
        'visualize_threshold': param.visualize_threshold,
        'save_file': param.save_file,
    }


def create_detection_output(layer,
                            node_name: str,
                            inputs_name: List[str],
                            outputs_name: List[str],
                            inputs_shape: List, ) -> onnx.NodeProto:
    """Build the "DetectionOutput" node for *layer*.

    Args:
        layer: Source layer; its attributes are gathered with
            ``create_attribuates``.
        node_name: Name given to the created node.
        inputs_name: Names of the node inputs.
        outputs_name: Names of the node outputs.
        inputs_shape: Shapes of the inputs, forwarded unchanged.

    Returns:
        Whatever ``Node.c2oNode`` constructs from these arguments.
        NOTE(review): the signature is annotated ``onnx.NodeProto``;
        confirm that this matches the wrapper type returned by
        ``c2oNode``.
    """
    detection_attrs = create_attribuates(layer)

    # The output shape is hard-coded rather than derived from inputs_shape.
    # NOTE(review): presumably one detection row of 7 values -- confirm.
    fixed_outputs_shape = [[1, 1, 1, 7]]

    return Node.c2oNode(
        layer, node_name, "DetectionOutput",
        inputs_name, outputs_name,
        inputs_shape, fixed_outputs_shape,
        detection_attrs,
    )