"examples/vscode:/vscode.git/clone" did not exist on "17cece072aec2007e7c3febb99455b66fd485af5"
ocr_det.cpp 7.06 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/ocr_det.h>
MissPenguin's avatar
MissPenguin committed
16

littletomatodonkey's avatar
littletomatodonkey committed
17
18
19

namespace PaddleOCR {

littletomatodonkey's avatar
littletomatodonkey committed
20
void DBDetector::LoadModel(const std::string &model_dir) {
LDOUBLEV's avatar
LDOUBLEV committed
21
22
  //   AnalysisConfig config;
  paddle_infer::Config config;
WenmuZhou's avatar
WenmuZhou committed
23
24
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
littletomatodonkey's avatar
littletomatodonkey committed
25

littletomatodonkey's avatar
littletomatodonkey committed
26
27
  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
28
    if (this->use_tensorrt_) {
MissPenguin's avatar
MissPenguin committed
29
30
31
32
33
34
35
      auto precision = paddle_infer::Config::Precision::kFloat32;
      if (this->precision_ == "fp16") {
        precision = paddle_infer::Config::Precision::kHalf;
      }
     if (this->precision_ == "int8") {
        precision = paddle_infer::Config::Precision::kInt8;
      } 
36
37
      config.EnableTensorRtEngine(
          1 << 20, 10, 3,
MissPenguin's avatar
MissPenguin committed
38
          precision,
39
          false, false);
LDOUBLEV's avatar
LDOUBLEV committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
      std::map<std::string, std::vector<int>> min_input_shape = {
          {"x", {1, 3, 50, 50}},
          {"conv2d_92.tmp_0", {1, 96, 20, 20}},
          {"conv2d_91.tmp_0", {1, 96, 10, 10}},
          {"nearest_interp_v2_1.tmp_0", {1, 96, 10, 10}},
          {"nearest_interp_v2_2.tmp_0", {1, 96, 20, 20}},
          {"nearest_interp_v2_3.tmp_0", {1, 24, 20, 20}},
          {"nearest_interp_v2_4.tmp_0", {1, 24, 20, 20}},
          {"nearest_interp_v2_5.tmp_0", {1, 24, 20, 20}},
          {"elementwise_add_7", {1, 56, 2, 2}},
          {"nearest_interp_v2_0.tmp_0", {1, 96, 2, 2}}};
      std::map<std::string, std::vector<int>> max_input_shape = {
          {"x", {1, 3, this->max_side_len_, this->max_side_len_}},
          {"conv2d_92.tmp_0", {1, 96, 400, 400}},
          {"conv2d_91.tmp_0", {1, 96, 200, 200}},
          {"nearest_interp_v2_1.tmp_0", {1, 96, 200, 200}},
          {"nearest_interp_v2_2.tmp_0", {1, 96, 400, 400}},
          {"nearest_interp_v2_3.tmp_0", {1, 24, 400, 400}},
          {"nearest_interp_v2_4.tmp_0", {1, 24, 400, 400}},
          {"nearest_interp_v2_5.tmp_0", {1, 24, 400, 400}},
          {"elementwise_add_7", {1, 56, 400, 400}},
          {"nearest_interp_v2_0.tmp_0", {1, 96, 400, 400}}};
      std::map<std::string, std::vector<int>> opt_input_shape = {
          {"x", {1, 3, 640, 640}},
          {"conv2d_92.tmp_0", {1, 96, 160, 160}},
          {"conv2d_91.tmp_0", {1, 96, 80, 80}},
          {"nearest_interp_v2_1.tmp_0", {1, 96, 80, 80}},
          {"nearest_interp_v2_2.tmp_0", {1, 96, 160, 160}},
          {"nearest_interp_v2_3.tmp_0", {1, 24, 160, 160}},
          {"nearest_interp_v2_4.tmp_0", {1, 24, 160, 160}},
          {"nearest_interp_v2_5.tmp_0", {1, 24, 160, 160}},
          {"elementwise_add_7", {1, 56, 40, 40}},
          {"nearest_interp_v2_0.tmp_0", {1, 96, 40, 40}}};

      config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
                                    opt_input_shape);
76
    }
littletomatodonkey's avatar
littletomatodonkey committed
77
78
  } else {
    config.DisableGpu();
littletomatodonkey's avatar
littletomatodonkey committed
79
80
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
WenmuZhou's avatar
WenmuZhou committed
81
82
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
littletomatodonkey's avatar
littletomatodonkey committed
83
    }
littletomatodonkey's avatar
littletomatodonkey committed
84
85
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }
LDOUBLEV's avatar
LDOUBLEV committed
86
87
  // use zero_copy_run as default
  config.SwitchUseFeedFetchOps(false);
littletomatodonkey's avatar
littletomatodonkey committed
88
  // true for multiple input
littletomatodonkey's avatar
littletomatodonkey committed
89
  config.SwitchSpecifyInputNames(true);
littletomatodonkey's avatar
littletomatodonkey committed
90
91
92
93

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
LDOUBLEV's avatar
LDOUBLEV committed
94
  // config.DisableGlogInfo();
littletomatodonkey's avatar
littletomatodonkey committed
95

LDOUBLEV's avatar
LDOUBLEV committed
96
  this->predictor_ = CreatePredictor(config);
littletomatodonkey's avatar
littletomatodonkey committed
97
98
99
}

void DBDetector::Run(cv::Mat &img,
MissPenguin's avatar
MissPenguin committed
100
101
                     std::vector<std::vector<std::vector<int>>> &boxes,
                     std::vector<double> *times) {
littletomatodonkey's avatar
littletomatodonkey committed
102
103
104
105
106
107
  float ratio_h{};
  float ratio_w{};

  cv::Mat srcimg;
  cv::Mat resize_img;
  img.copyTo(srcimg);
MissPenguin's avatar
MissPenguin committed
108
109
  
  auto preprocess_start = std::chrono::steady_clock::now();
root's avatar
root committed
110
111
  this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
                       this->use_tensorrt_);
littletomatodonkey's avatar
littletomatodonkey committed
112
113
114
115

  this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                          this->is_scale_);

littletomatodonkey's avatar
littletomatodonkey committed
116
117
  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());
MissPenguin's avatar
MissPenguin committed
118
119
  auto preprocess_end = std::chrono::steady_clock::now();
    
120
  // Inference.
LDOUBLEV's avatar
LDOUBLEV committed
121
122
123
  auto input_names = this->predictor_->GetInputNames();
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
MissPenguin's avatar
update  
MissPenguin committed
124
  auto inference_start = std::chrono::steady_clock::now();
LDOUBLEV's avatar
LDOUBLEV committed
125
  input_t->CopyFromCpu(input.data());
MissPenguin's avatar
MissPenguin committed
126
  
LDOUBLEV's avatar
LDOUBLEV committed
127
  this->predictor_->Run();
MissPenguin's avatar
MissPenguin committed
128
    
littletomatodonkey's avatar
littletomatodonkey committed
129
130
  std::vector<float> out_data;
  auto output_names = this->predictor_->GetOutputNames();
LDOUBLEV's avatar
LDOUBLEV committed
131
  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
littletomatodonkey's avatar
littletomatodonkey committed
132
133
134
135
136
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());

  out_data.resize(out_num);
LDOUBLEV's avatar
LDOUBLEV committed
137
  output_t->CopyToCpu(out_data.data());
MissPenguin's avatar
MissPenguin committed
138
139
140
  auto inference_end = std::chrono::steady_clock::now();
  
  auto postprocess_start = std::chrono::steady_clock::now();
littletomatodonkey's avatar
littletomatodonkey committed
141
142
143
144
  int n2 = output_shape[2];
  int n3 = output_shape[3];
  int n = n2 * n3;

littletomatodonkey's avatar
littletomatodonkey committed
145
146
  std::vector<float> pred(n, 0.0);
  std::vector<unsigned char> cbuf(n, ' ');
littletomatodonkey's avatar
littletomatodonkey committed
147
148
149
150
151
152

  for (int i = 0; i < n; i++) {
    pred[i] = float(out_data[i]);
    cbuf[i] = (unsigned char)((out_data[i]) * 255);
  }

littletomatodonkey's avatar
littletomatodonkey committed
153
154
  cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
  cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
littletomatodonkey's avatar
littletomatodonkey committed
155

littletomatodonkey's avatar
littletomatodonkey committed
156
  const double threshold = this->det_db_thresh_ * 255;
littletomatodonkey's avatar
littletomatodonkey committed
157
158
159
  const double maxvalue = 255;
  cv::Mat bit_map;
  cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
WenmuZhou's avatar
WenmuZhou committed
160
161
162
  cv::Mat dilation_map;
  cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
  cv::dilate(bit_map, dilation_map, dila_ele);
163
164
165
  boxes = post_processor_.BoxesFromBitmap(
      pred_map, dilation_map, this->det_db_box_thresh_,
      this->det_db_unclip_ratio_, this->use_polygon_score_);
littletomatodonkey's avatar
littletomatodonkey committed
166

littletomatodonkey's avatar
littletomatodonkey committed
167
  boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
MissPenguin's avatar
MissPenguin committed
168
  auto postprocess_end = std::chrono::steady_clock::now();
MissPenguin's avatar
update  
MissPenguin committed
169
  std::cout << "Detected boxes num: " << boxes.size() << endl;
MissPenguin's avatar
MissPenguin committed
170
171
172
173
174
175
176

  std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
  times->push_back(double(preprocess_diff.count() * 1000));
  std::chrono::duration<float> inference_diff = inference_end - inference_start;
  times->push_back(double(inference_diff.count() * 1000));
  std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
  times->push_back(double(postprocess_diff.count() * 1000));
MissPenguin's avatar
MissPenguin committed
177
    
littletomatodonkey's avatar
littletomatodonkey committed
178
  //// visualization
littletomatodonkey's avatar
littletomatodonkey committed
179
180
  if (this->visualize_) {
    Utility::VisualizeBboxes(srcimg, boxes);
littletomatodonkey's avatar
littletomatodonkey committed
181
182
183
  }
}

littletomatodonkey's avatar
littletomatodonkey committed
184
} // namespace PaddleOCR