ocr_det.cpp 6.07 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/ocr_det.h>
MissPenguin's avatar
MissPenguin committed
16

littletomatodonkey's avatar
littletomatodonkey committed
17
18
19

namespace PaddleOCR {

littletomatodonkey's avatar
littletomatodonkey committed
20
void DBDetector::LoadModel(const std::string &model_dir) {
LDOUBLEV's avatar
LDOUBLEV committed
21
22
  //   AnalysisConfig config;
  paddle_infer::Config config;
WenmuZhou's avatar
WenmuZhou committed
23
24
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
littletomatodonkey's avatar
littletomatodonkey committed
25

littletomatodonkey's avatar
littletomatodonkey committed
26
27
  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
28
29
30
31
32
33
    if (this->use_tensorrt_) {
      config.EnableTensorRtEngine(
          1 << 20, 10, 3,
          this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                          : paddle_infer::Config::Precision::kFloat32,
          false, false);
LDOUBLEV's avatar
LDOUBLEV committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
      std::map<std::string, std::vector<int>> min_input_shape = {
          {"x", {1, 3, 50, 50}},
          {"conv2d_92.tmp_0", {1, 96, 20, 20}},
          {"conv2d_91.tmp_0", {1, 96, 10, 10}},
          {"nearest_interp_v2_1.tmp_0", {1, 96, 10, 10}},
          {"nearest_interp_v2_2.tmp_0", {1, 96, 20, 20}},
          {"nearest_interp_v2_3.tmp_0", {1, 24, 20, 20}},
          {"nearest_interp_v2_4.tmp_0", {1, 24, 20, 20}},
          {"nearest_interp_v2_5.tmp_0", {1, 24, 20, 20}},
          {"elementwise_add_7", {1, 56, 2, 2}},
          {"nearest_interp_v2_0.tmp_0", {1, 96, 2, 2}}};
      std::map<std::string, std::vector<int>> max_input_shape = {
          {"x", {1, 3, this->max_side_len_, this->max_side_len_}},
          {"conv2d_92.tmp_0", {1, 96, 400, 400}},
          {"conv2d_91.tmp_0", {1, 96, 200, 200}},
          {"nearest_interp_v2_1.tmp_0", {1, 96, 200, 200}},
          {"nearest_interp_v2_2.tmp_0", {1, 96, 400, 400}},
          {"nearest_interp_v2_3.tmp_0", {1, 24, 400, 400}},
          {"nearest_interp_v2_4.tmp_0", {1, 24, 400, 400}},
          {"nearest_interp_v2_5.tmp_0", {1, 24, 400, 400}},
          {"elementwise_add_7", {1, 56, 400, 400}},
          {"nearest_interp_v2_0.tmp_0", {1, 96, 400, 400}}};
      std::map<std::string, std::vector<int>> opt_input_shape = {
          {"x", {1, 3, 640, 640}},
          {"conv2d_92.tmp_0", {1, 96, 160, 160}},
          {"conv2d_91.tmp_0", {1, 96, 80, 80}},
          {"nearest_interp_v2_1.tmp_0", {1, 96, 80, 80}},
          {"nearest_interp_v2_2.tmp_0", {1, 96, 160, 160}},
          {"nearest_interp_v2_3.tmp_0", {1, 24, 160, 160}},
          {"nearest_interp_v2_4.tmp_0", {1, 24, 160, 160}},
          {"nearest_interp_v2_5.tmp_0", {1, 24, 160, 160}},
          {"elementwise_add_7", {1, 56, 40, 40}},
          {"nearest_interp_v2_0.tmp_0", {1, 96, 40, 40}}};

      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                    opt_input_shape);
70
    }
littletomatodonkey's avatar
littletomatodonkey committed
71
72
  } else {
    config.DisableGpu();
littletomatodonkey's avatar
littletomatodonkey committed
73
74
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
WenmuZhou's avatar
WenmuZhou committed
75
76
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
littletomatodonkey's avatar
littletomatodonkey committed
77
    }
littletomatodonkey's avatar
littletomatodonkey committed
78
79
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }
LDOUBLEV's avatar
LDOUBLEV committed
80
81
  // use zero_copy_run as default
  config.SwitchUseFeedFetchOps(false);
littletomatodonkey's avatar
littletomatodonkey committed
82
  // true for multiple input
littletomatodonkey's avatar
littletomatodonkey committed
83
  config.SwitchSpecifyInputNames(true);
littletomatodonkey's avatar
littletomatodonkey committed
84
85
86
87

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
LDOUBLEV's avatar
LDOUBLEV committed
88
  // config.DisableGlogInfo();
littletomatodonkey's avatar
littletomatodonkey committed
89

LDOUBLEV's avatar
LDOUBLEV committed
90
  this->predictor_ = CreatePredictor(config);
littletomatodonkey's avatar
littletomatodonkey committed
91
92
93
94
95
96
97
98
99
100
}

// Runs text detection on `img` and stores the resulting boxes in `boxes`.
//
// Parameters:
//   img   - input image; a copy is taken up front (srcimg) and used for the
//           final box filtering and the optional visualization.
//   boxes - output; overwritten with the detected boxes. Left empty when the
//           predictor exposes no input/output tensor or returns an output
//           whose shape is not usable (fewer than 4 dims, or non-positive
//           spatial dims).
//
// Pipeline: resize -> normalize -> permute into a flat float buffer ->
// predictor run -> threshold + dilate the output map -> BoxesFromBitmap ->
// FilterTagDetRes -> optional visualization.
void DBDetector::Run(cv::Mat &img,
                     std::vector<std::vector<std::vector<int>>> &boxes) {
  float ratio_h{};
  float ratio_w{};

  cv::Mat srcimg;
  cv::Mat resize_img;
  img.copyTo(srcimg);

  // ratio_h / ratio_w are handed to the resize op and later forwarded to
  // FilterTagDetRes (presumably filled in with the resize scale factors —
  // confirm against the resize op).
  this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
                       this->use_tensorrt_);

  this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                          this->is_scale_);

  // Flat buffer sized for a 1 x 3 x rows x cols tensor.
  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());

  // Inference.
  auto input_names = this->predictor_->GetInputNames();
  if (input_names.empty()) {
    std::cerr << "DBDetector::Run: predictor exposes no input tensor"
              << std::endl;
    boxes.clear();
    return;
  }
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
  input_t->CopyFromCpu(input.data());
  this->predictor_->Run();

  auto output_names = this->predictor_->GetOutputNames();
  if (output_names.empty()) {
    std::cerr << "DBDetector::Run: predictor exposes no output tensor"
              << std::endl;
    boxes.clear();
    return;
  }
  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);

  const std::vector<int> output_shape = output_t->shape();
  // Dims 2 and 3 are read below as the map's rows and columns, so the shape
  // must have at least four dims.
  if (output_shape.size() < 4) {
    std::cerr << "DBDetector::Run: unexpected output rank "
              << output_shape.size() << std::endl;
    boxes.clear();
    return;
  }

  const int out_num = std::accumulate(output_shape.begin(), output_shape.end(),
                                      1, std::multiplies<int>());
  const int n2 = output_shape[2];
  const int n3 = output_shape[3];
  // Reject unresolved (non-positive) dims before they reach resize() or the
  // cv::Mat constructors, and make sure the n2 x n3 map fits in the output.
  if (n2 <= 0 || n3 <= 0 || out_num < n2 * n3) {
    std::cerr << "DBDetector::Run: unusable output shape" << std::endl;
    boxes.clear();
    return;
  }
  const int n = n2 * n3;

  std::vector<float> out_data(out_num);
  output_t->CopyToCpu(out_data.data());

  // 8-bit copy of the first n output values scaled by 255. Values are clamped
  // to [0, 255] before truncation because converting an out-of-range float to
  // unsigned char is undefined behaviour; NaN maps to 0.
  std::vector<unsigned char> cbuf(n, ' ');
  for (int i = 0; i < n; i++) {
    const float scaled = out_data[i] * 255;
    const float clamped =
        !(scaled > 0.0f) ? 0.0f : (scaled > 255.0f ? 255.0f : scaled);
    cbuf[i] = static_cast<unsigned char>(clamped);
  }

  // Both Mats are non-owning views: cbuf and out_data must outlive them, which
  // holds because everything is consumed before this function returns.
  // pred_map views the first n floats of out_data directly (no extra copy).
  cv::Mat cbuf_map(n2, n3, CV_8UC1, cbuf.data());
  cv::Mat pred_map(n2, n3, CV_32F, out_data.data());

  const double threshold = this->det_db_thresh_ * 255;
  const double maxvalue = 255;
  cv::Mat bit_map;
  cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);

  // Dilate the binary map with a 2x2 rectangular element before box
  // extraction.
  cv::Mat dilation_map;
  cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
  cv::dilate(bit_map, dilation_map, dila_ele);

  boxes = post_processor_.BoxesFromBitmap(
      pred_map, dilation_map, this->det_db_box_thresh_,
      this->det_db_unclip_ratio_, this->use_polygon_score_);

  boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
  std::cout << "Detected boxes num: " << boxes.size() << std::endl;

  //// visualization
  if (this->visualize_) {
    Utility::VisualizeBboxes(srcimg, boxes);
  }
}

littletomatodonkey's avatar
littletomatodonkey committed
162
} // namespace PaddleOCR