Flatten.py 1.12 KB
Newer Older
yaoht's avatar
yaoht committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import src.c2oObject as Node
from typing import List, Dict
import onnx


def get_attributes(layer) -> Dict:
    """Extract the ONNX ``Flatten`` attributes from a Caffe Flatten layer.

    Only ``flatten_param.axis`` is carried over. This converter handles
    only ``flatten_param.end_axis == -1``; any other value aborts the run.

    Args:
        layer: layer object exposing ``flatten_param.axis`` and
            ``flatten_param.end_axis``.

    Returns:
        Dict with a single ``"axis"`` entry.

    Raises:
        SystemExit: with code -1 when ``end_axis`` is not -1.
    """
    # sys.exit instead of the builtin exit(): the latter is added by the
    # `site` module for interactive use and is missing under `python -S`.
    import sys

    params = layer.flatten_param
    axis = params.axis
    end_axis = params.end_axis
    if end_axis != -1:
        print("not support end_axis param!")
        sys.exit(-1)
    return {"axis": axis}


def get_flatten_output_shape(input_shape: List,
                             attributes: Dict) -> List:
    """Compute the output shape of the Flatten layer.

    Dimensions before ``axis`` are kept unchanged; dimensions from ``axis``
    to the end are collapsed into their product.

    Args:
        input_shape: list of input shapes; only the first entry is used.
        attributes: attribute dict (as built by ``get_attributes``).
            ``"axis"`` defaults to 1 when absent; a negative value counts
            from the last dimension.

    Returns:
        A one-element list holding the output shape as a list of ints.

    Raises:
        ValueError: if ``axis`` lies outside ``[-rank, rank]``.
    """
    shape = list(input_shape[0])  # accept tuples as well as lists
    rank = len(shape)
    requested_axis = attributes.get("axis", 1)
    axis = requested_axis
    # Canonicalize a negative axis; without this the product loop below
    # would start at a negative index and count dimensions twice.
    if axis < 0:
        axis += rank
    if not 0 <= axis <= rank:
        raise ValueError(
            "Flatten axis {} is out of range for input rank {}".format(
                requested_axis, rank))

    flattened = 1
    for dim in shape[axis:]:
        flattened *= dim

    # NOTE(review): ONNX Flatten always produces a 2-D tensor
    # [prod(shape[:axis]), prod(shape[axis:])]. The shape returned here
    # agrees with that only when axis == 1; confirm how other axes are
    # meant to be handled downstream.
    return [shape[:axis] + [flattened]]

def create_flatten_node(layer, node_name : str,
                        input_names: List,
                        output_name: List,
                        input_shape: List) -> onnx.NodeProto:
    """Build the converter node for a Flatten layer.

    Reads the layer's attributes via ``get_attributes``, derives the output
    shape via ``get_flatten_output_shape``, and wraps everything in a
    ``Node.c2oNode`` of op type ``"Flatten"``.

    Args:
        layer: source layer carrying ``flatten_param``.
        node_name: name to give the generated node.
        input_names: names of the node's inputs.
        output_name: names of the node's outputs.
        input_shape: list of input shapes; the first is used for inference.

    Returns:
        The ``Node.c2oNode`` instance built for this layer.
    """
    flatten_attrs = get_attributes(layer)
    flatten_out_shape = get_flatten_output_shape(input_shape, flatten_attrs)
    return Node.c2oNode(
        layer,
        node_name,
        "Flatten",
        input_names,
        output_name,
        input_shape,
        flatten_out_shape,
        flatten_attrs,
    )