caffe2onnx.py 1.08 KB
Newer Older
yaoht's avatar
yaoht committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from src.load_save_model import LoadCaffeModel, SaveOnnxModel
from src.caffe2onnx import Caffe2Onnx
from src.args_parser import parse_args
from src.utils import is_ssd_model

def main(args):
    caffe_graph_path = args.proto_file
    caffe_params_path = args.caffe_model_file

    pos_s = caffe_graph_path.rfind("/")
    if  pos_s == -1:
        pos_s = 0

    pos_dot = caffe_graph_path.rfind(".")
    onnx_name = caffe_graph_path[pos_s+1:pos_dot]
    save_path = caffe_graph_path[0:pos_dot] + '.onnx'
    if args.onnx_file is not None:
        save_path = args.onnx_file

    graph, params = LoadCaffeModel(caffe_graph_path,caffe_params_path)
    print('2. 开始进行模型转换')
    c2o = Caffe2Onnx(graph, params, onnx_name)
    print('3. 创建 onnx 模型')
    onnx_model = c2o.createOnnxModel()
    print('4. 保存 onnx 模型')
    # is_ssd = is_ssd_model(caffe_graph_path)
    # if is_ssd:
    SaveOnnxModel(onnx_model, save_path, need_polish=False)
    # else:
    #     SaveOnnxModel(onnx_model, save_path, need_polish=True)


if __name__ == '__main__':
    args = parse_args()
    main(args)