dbnet_infer.py 3.48 KB
Newer Older
chenxj's avatar
chenxj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import sys
import threading
import time

import cv2
import numpy as np
import onnxruntime as rt
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

from decode import  SegDetectorRepresenter

# Per-channel RGB normalization constants applied in DBNET.process
# (the widely used ImageNet mean/std statistics).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)


def Singleton(cls):
    """Class decorator that makes *cls* lazily instantiate exactly once.

    The first call builds the instance with the given arguments; every
    later call returns that same cached object (later arguments are
    ignored).
    """
    _cache = {}

    def _singleton(*args, **kargs):
        try:
            return _cache[cls]
        except KeyError:
            _cache[cls] = cls(*args, **kargs)
            return _cache[cls]

    return _singleton


class SingletonType(type):
    """Metaclass giving each class a single shared instance.

    Bug fix: the previous implementation simply re-ran ``__new__`` and
    ``__init__`` on every call, creating a brand-new object each time —
    it provided no singleton behavior at all despite the name (so each
    ``DBNET(...)`` reloaded the ONNX model). This version caches one
    instance per class, created under a double-checked lock so
    concurrent first calls are safe.
    """

    _instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        # The instance is keyed per class, not per constructor arguments:
        # later calls with different args return the first instance
        # (standard singleton semantics).
        if not hasattr(cls, "_instance"):
            with SingletonType._instance_lock:
                if not hasattr(cls, "_instance"):
                    cls._instance = super().__call__(*args, **kwargs)
        return cls._instance


def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
    """Draw each detected polygon in *result* onto a copy of the image.

    Args:
        img_path: either a path string (loaded via ``cv2.imread``) or an
            already-decoded image array.
        result: iterable of point arrays, one closed polygon per detection.
        color: polyline color tuple.
        thickness: polyline thickness in pixels.

    Returns:
        A new image array with the boxes drawn; the input array is not
        mutated.
    """
    canvas = cv2.imread(img_path) if isinstance(img_path, str) else img_path
    canvas = canvas.copy()
    for quad in result:
        cv2.polylines(canvas, [quad.astype(int)], True, color, thickness)
    return canvas


class DBNET(metaclass=SingletonType):
    """ONNX Runtime wrapper around a DBNet text-detection model.

    Images are resized so their long side fits the square model input,
    center-padded onto a black canvas, normalized with the ImageNet
    statistics, and run through the network; the raw output map is
    post-processed by ``SegDetectorRepresenter`` into polygon boxes and
    per-box scores.
    """

    def __init__(self, MODEL_PATH):
        # Prefer the ROCm execution provider (device_id '4' is hard-coded
        # for this deployment), falling back to CPU when unavailable.
        self.sess = rt.InferenceSession(
            MODEL_PATH,
            providers=[('ROCMExecutionProvider', {'device_id': '4'}),
                       'CPUExecutionProvider'])
        self.decode_handel = SegDetectorRepresenter()

    def process(self, img, max_side=1280):
        """Detect text boxes in a BGR image.

        Args:
            img: H x W x 3 BGR uint8 image (OpenCV convention).
            max_side: size of the square model input; the longer image
                side is scaled to this value (default 1280, matching the
                exported model's resolution).

        Returns:
            ``(box_list, score_list)`` — arrays of polygon boxes and
            their scores, or ``([], [])`` when nothing is found.
            Returns ``(None, (None, None))`` for degenerate input sizes.

        Raises:
            ValueError: if ``cv2.resize`` fails on the computed size.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # Scale so the longer side becomes max_side, preserving aspect ratio.
        if h > w:
            resize_h = max_side
            ratio = float(max_side) / h
            resize_w = int(w * ratio)
        else:
            resize_w = max_side
            ratio = float(max_side) / w
            resize_h = int(h * ratio)

        if int(resize_w) <= 0 or int(resize_h) <= 0:
            return None, (None, None)
        try:
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except cv2.error as e:
            # Fix: the previous code swallowed every exception with a bare
            # except and then killed the whole process via sys.exit(0);
            # surface the failure to the caller instead.
            raise ValueError(
                "cv2.resize failed for shape %s -> (%s, %s)"
                % (img.shape, resize_w, resize_h)) from e

        # Center-pad the resized image onto a black max_side x max_side canvas.
        padding_im = np.zeros((max_side, max_side, 3), dtype=np.uint8)
        top = int((max_side - resize_h) / 2)
        left = int((max_side - resize_w) / 2)
        padding_im[top:top + int(resize_h), left:left + int(resize_w), :] = img

        # Normalize (ImageNet mean/std) and convert HWC -> NCHW float32.
        padding_im = padding_im.astype(np.float32)
        padding_im /= 255.0
        padding_im -= mean
        padding_im /= std
        padding_im = padding_im.transpose(2, 0, 1)
        transformed_image = np.expand_dims(padding_im, axis=0)

        out = self.sess.run(["out1"], {"input0": transformed_image.astype(np.float32)})
        box_list, score_list = self.decode_handel(out[0][0], h, w, resize_h, resize_w)
        if len(box_list) > 0:
            # Drop degenerate boxes whose coordinates are all zero.
            idx = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
            box_list, score_list = box_list[idx], score_list[idx]
        else:
            box_list, score_list = [], []
        return box_list, score_list


if __name__ == "__main__":
    # Smoke test: run detection on a sample image and save the overlay.
    text_handle = DBNET(MODEL_PATH="./models/dbnet.onnx")
    img = cv2.imread("./images/1.jpg")
    print(img.shape)
    box_list, score_list = text_handle.process(img)
    annotated = draw_bbox(img, box_list)
    cv2.imwrite("test.jpg", annotated)